From 1e5d85e7b42763219b9746157b4a899c1798cf93 Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Fri, 17 Jun 2022 12:55:22 -0700 Subject: [PATCH] cirq-core target gatesets: accept additional gates to keep untouched. (#5445) Builds on top of https://github.com/quantumlib/Cirq/pull/5429 The internal gate representation for `additional_gates` is updated to match `cirq.Gateset`: * Equality check uses GateFamily representation. Otherwise different representations of the gate will not be considered equal. * JSON uses GateFamily representation. * repr uses the representation passed in via the constructor. `assert_optimizes` in `cz_gateset_test.py` is updated to take in an optional `additional_gates` instead, to exercise CZTargetGateset constructor's defaulting logic. No tests are added since `additional_gates` need to be set in existing tests after `ignore_errors` is set to False. @tanujkhattar --- .../cirq/contrib/paulistring/optimize_test.py | 7 +- .../json_test_data/CZTargetGateset.json | 38 +++++++++++ .../json_test_data/CZTargetGateset.repr | 18 ++++- .../SqrtIswapTargetGateset.json | 39 +++++++++++ .../SqrtIswapTargetGateset.repr | 25 ++++++- .../target_gatesets/cz_gateset.py | 59 ++++++++++++++--- .../target_gatesets/cz_gateset_test.py | 64 ++++++++++++++---- .../target_gatesets/sqrt_iswap_gateset.py | 51 ++++++++++++-- .../sqrt_iswap_gateset_test.py | 66 ++++++++++++++----- 9 files changed, 314 insertions(+), 53 deletions(-) diff --git a/cirq-core/cirq/contrib/paulistring/optimize_test.py b/cirq-core/cirq/contrib/paulistring/optimize_test.py index c5b14faf68c..235683852f9 100644 --- a/cirq-core/cirq/contrib/paulistring/optimize_test.py +++ b/cirq-core/cirq/contrib/paulistring/optimize_test.py @@ -50,14 +50,15 @@ def test_optimize(): cirq.testing.assert_allclose_up_to_global_phase(c_orig.unitary(), c_opt.unitary(), atol=1e-6) + # TODO(#5546) Fix '[Z]^1' (should be 'Z') cirq.testing.assert_has_diagram( c_opt, """ -0: ───X^0.5────────────@──────────────────────────────────────── +0: ───X^0.5────────────@────────────────────────────────────────────── │ -1: ───@───────X^-0.5───@───@────────────────@───Z^-0.5────────── +1: ───@───────X^-0.5───@───@────────────────@───Z^-0.5──────────────── │ │ │ -2: ───@────────────────────@───[X]^(-7/8)───@───[X]^-0.25───Z─── +2: ───@────────────────────@───[X]^(-7/8)───@───[X]^-0.25───[Z]^(1)─── """, ) diff --git a/cirq-core/cirq/protocols/json_test_data/CZTargetGateset.json b/cirq-core/cirq/protocols/json_test_data/CZTargetGateset.json index be77ff8faea..23f3615b41d 100644 --- a/cirq-core/cirq/protocols/json_test_data/CZTargetGateset.json +++ b/cirq-core/cirq/protocols/json_test_data/CZTargetGateset.json @@ -8,5 +8,43 @@ "cirq_type": "CZTargetGateset", "atol": 1e-08, "allow_partial_czs": true + }, + { + "cirq_type": "CZTargetGateset", + "atol": 1e-06, + "allow_partial_czs": true, + "additional_gates": [ + { + "cirq_type": "GateFamily", + "gate": { + "cirq_type": "ISwapPowGate", + "exponent": 0.5, + "global_shift": 0.0 + }, + "name": "Instance GateFamily: ISWAP**0.5", + "description": "Accepts `cirq.Gate` instances `g` s.t. `g == ISWAP**0.5`", + "ignore_global_phase": true, + "tags_to_accept": [], + "tags_to_ignore": [] + }, + { + "cirq_type": "GateFamily", + "gate": "XPowGate", + "name": "Type GateFamily: cirq.ops.common_gates.XPowGate", + "description": "Accepts `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.XPowGate)`", + "ignore_global_phase": true, + "tags_to_accept": [], + "tags_to_ignore": [] + }, + { + "cirq_type": "GateFamily", + "gate": "ZPowGate", + "name": "Type GateFamily: cirq.ops.common_gates.ZPowGate", + "description": "Accepts `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.ZPowGate)`", + "ignore_global_phase": true, + "tags_to_accept": [], + "tags_to_ignore": [] + } + ] } ] \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/CZTargetGateset.repr b/cirq-core/cirq/protocols/json_test_data/CZTargetGateset.repr index 0821e1394d6..3d0df0caa6f 100644 --- a/cirq-core/cirq/protocols/json_test_data/CZTargetGateset.repr +++ b/cirq-core/cirq/protocols/json_test_data/CZTargetGateset.repr @@ -1,4 +1,18 @@ [ - cirq.CZTargetGateset(atol=1e-06, allow_partial_czs=False), - cirq.CZTargetGateset(atol=1e-08, allow_partial_czs=True), + cirq.CZTargetGateset(atol=1e-06, allow_partial_czs=False, additional_gates=[]), + cirq.CZTargetGateset(atol=1e-08, allow_partial_czs=True, additional_gates=[]), + cirq.CZTargetGateset( + atol=1e-06, + allow_partial_czs=True, + additional_gates=[ + (cirq.ISWAP**0.5), + cirq.ops.common_gates.XPowGate, + cirq.GateFamily( + gate=cirq.ops.common_gates.ZPowGate, + ignore_global_phase=True, + tags_to_accept=frozenset(), + tags_to_ignore=frozenset(), + ), + ], + ), ] diff --git a/cirq-core/cirq/protocols/json_test_data/SqrtIswapTargetGateset.json b/cirq-core/cirq/protocols/json_test_data/SqrtIswapTargetGateset.json index d52e44147c4..31b16b6b82c 100644 --- a/cirq-core/cirq/protocols/json_test_data/SqrtIswapTargetGateset.json +++ b/cirq-core/cirq/protocols/json_test_data/SqrtIswapTargetGateset.json @@ -16,5 +16,44 @@ "atol": 1e-06, "required_sqrt_iswap_count": 2, "use_sqrt_iswap_inv": true + }, + { + "cirq_type": "SqrtIswapTargetGateset", + "atol": 1e-08, + "required_sqrt_iswap_count": null, + "use_sqrt_iswap_inv": false, + "additional_gates": [ + { + "cirq_type": "GateFamily", + "gate": { + "cirq_type": "CZPowGate", + "exponent": 1.0, + "global_shift": 0.0 + }, + "name": "Instance GateFamily: CZ", + "description": "Accepts `cirq.Gate` instances `g` s.t. `g == CZ`", + "ignore_global_phase": true, + "tags_to_accept": [], + "tags_to_ignore": [] + }, + { + "cirq_type": "GateFamily", + "gate": "XPowGate", + "name": "Type GateFamily: cirq.ops.common_gates.XPowGate", + "description": "Accepts `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.XPowGate)`", + "ignore_global_phase": true, + "tags_to_accept": [], + "tags_to_ignore": [] + }, + { + "cirq_type": "GateFamily", + "gate": "ZPowGate", + "name": "Type GateFamily: cirq.ops.common_gates.ZPowGate", + "description": "Accepts `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.ZPowGate)`", + "ignore_global_phase": true, + "tags_to_accept": [], + "tags_to_ignore": [] + } + ] } ] \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/SqrtIswapTargetGateset.repr b/cirq-core/cirq/protocols/json_test_data/SqrtIswapTargetGateset.repr index 06c6315ca94..09527eb9516 100644 --- a/cirq-core/cirq/protocols/json_test_data/SqrtIswapTargetGateset.repr +++ b/cirq-core/cirq/protocols/json_test_data/SqrtIswapTargetGateset.repr @@ -1,7 +1,26 @@ [ cirq.SqrtIswapTargetGateset( - atol=1e-08, required_sqrt_iswap_count=None, use_sqrt_iswap_inv=False + atol=1e-08, required_sqrt_iswap_count=None, use_sqrt_iswap_inv=False, additional_gates=[] + ), + cirq.SqrtIswapTargetGateset( + atol=1e-08, required_sqrt_iswap_count=1, use_sqrt_iswap_inv=False, additional_gates=[] + ), + cirq.SqrtIswapTargetGateset( + atol=1e-06, required_sqrt_iswap_count=2, use_sqrt_iswap_inv=True, additional_gates=[] + ), + cirq.SqrtIswapTargetGateset( + atol=1e-08, + required_sqrt_iswap_count=None, + use_sqrt_iswap_inv=False, + additional_gates=[ + cirq.CZ, + cirq.ops.common_gates.XPowGate, + cirq.GateFamily( + gate=cirq.ops.common_gates.ZPowGate, + ignore_global_phase=True, + tags_to_accept=frozenset(), + tags_to_ignore=frozenset(), + ), + ], ), - cirq.SqrtIswapTargetGateset(atol=1e-08, required_sqrt_iswap_count=1, use_sqrt_iswap_inv=False), - cirq.SqrtIswapTargetGateset(atol=1e-06, required_sqrt_iswap_count=2, use_sqrt_iswap_inv=True), ] diff --git a/cirq-core/cirq/transformers/target_gatesets/cz_gateset.py b/cirq-core/cirq/transformers/target_gatesets/cz_gateset.py index fa83f481120..af00f242a9d 100644 --- a/cirq-core/cirq/transformers/target_gatesets/cz_gateset.py +++ b/cirq-core/cirq/transformers/target_gatesets/cz_gateset.py @@ -14,7 +14,7 @@ """Target gateset used for compiling circuits to CZ + 1-q rotations + measurement gates.""" -from typing import Any, Dict, TYPE_CHECKING +from typing import Any, Dict, Sequence, Type, Union, TYPE_CHECKING from cirq import ops, protocols from cirq.transformers.analytical_decompositions import two_qubit_to_cz @@ -25,23 +25,53 @@ class CZTargetGateset(compilation_target_gateset.TwoQubitCompilationTargetGateset): - """Target gateset containing CZ + single qubit rotations + Measurement gates.""" + """Target gateset accepting CZ + single qubit rotations + measurement gates. - def __init__(self, *, atol: float = 1e-8, allow_partial_czs: bool = False) -> None: + By default, `cirq.CZTargetGateset` will accept and compile unknown gates to + the following universal target gateset: + - `cirq.CZ` / `cirq.CZPowGate`: The two qubit entangling gate. + - `cirq.PhasedXZGate`: Single qubit rotations. + - `cirq.MeasurementGate`: Measurements. + - `cirq.GlobalPhaseGate`: Global phase. + + Optionally, users can also specify additional gates / gate families which should + be accepted by this gateset via the `additional_gates` argument. + + When compiling a circuit, any unknown gate, i.e. a gate which is not accepted by + this gateset, will be compiled to the default gateset (i.e. `cirq.CZ`/`cirq.CZPowGate`, + `cirq.PhasedXZGate`, `cirq.MeasurementGate`). + """ + + def __init__( + self, + *, + atol: float = 1e-8, + allow_partial_czs: bool = False, + additional_gates: Sequence[Union[Type['cirq.Gate'], 'cirq.Gate', 'cirq.GateFamily']] = (), + ) -> None: """Initializes CZTargetGateset Args: atol: A limit on the amount of absolute error introduced by the decomposition. allow_partial_czs: If set, all powers of the form `cirq.CZ**t`, and not just `cirq.CZ`, are part of this gateset. + additional_gates: Sequence of additional gates / gate families which should also + be "accepted" by this gateset. Defaults to `cirq.GlobalPhaseGate`. """ super().__init__( ops.CZPowGate if allow_partial_czs else ops.CZ, ops.MeasurementGate, - ops.AnyUnitaryGateFamily(1), + ops.PhasedXZGate, ops.GlobalPhaseGate, + *additional_gates, name='CZPowTargetGateset' if allow_partial_czs else 'CZTargetGateset', ) + self.additional_gates = tuple( + g if isinstance(g, ops.GateFamily) else ops.GateFamily(gate=g) for g in additional_gates + ) + self._additional_gates_repr_str = ", ".join( + [ops.gateset._gate_str(g, repr) for g in additional_gates] + ) self.atol = atol self.allow_partial_czs = allow_partial_czs @@ -57,14 +87,25 @@ def _decompose_two_qubit_operation(self, op: 'cirq.Operation', _) -> 'cirq.OP_TR ) def __repr__(self) -> str: - return f'cirq.CZTargetGateset(atol={self.atol}, allow_partial_czs={self.allow_partial_czs})' + return ( + f'cirq.CZTargetGateset(' + f'atol={self.atol}, ' + f'allow_partial_czs={self.allow_partial_czs}, ' + f'additional_gates=[{self._additional_gates_repr_str}]' + f')' + ) def _value_equality_values_(self) -> Any: - return self.atol, self.allow_partial_czs + return self.atol, self.allow_partial_czs, frozenset(self.additional_gates) def _json_dict_(self) -> Dict[str, Any]: - return {'atol': self.atol, 'allow_partial_czs': self.allow_partial_czs} + d: Dict[str, Any] = {'atol': self.atol, 'allow_partial_czs': self.allow_partial_czs} + if self.additional_gates: + d['additional_gates'] = list(self.additional_gates) + return d @classmethod - def _from_json_dict_(cls, atol, allow_partial_czs, **kwargs): - return cls(atol=atol, allow_partial_czs=allow_partial_czs) + def _from_json_dict_(cls, atol, allow_partial_czs, additional_gates=(), **kwargs): + return cls( + atol=atol, allow_partial_czs=allow_partial_czs, additional_gates=additional_gates + ) diff --git a/cirq-core/cirq/transformers/target_gatesets/cz_gateset_test.py b/cirq-core/cirq/transformers/target_gatesets/cz_gateset_test.py index 40952cd895b..00f31414d2a 100644 --- a/cirq-core/cirq/transformers/target_gatesets/cz_gateset_test.py +++ b/cirq-core/cirq/transformers/target_gatesets/cz_gateset_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional, Sequence, Type import pytest import cirq import sympy @@ -25,9 +26,18 @@ def all_gates_of_type(m: cirq.Moment, g: cirq.Gateset): return True -def assert_optimizes(before: cirq.Circuit, expected: cirq.Circuit): +def assert_optimizes( + before: cirq.Circuit, + expected: cirq.Circuit, + additional_gates: Optional[Sequence[Type[cirq.Gate]]] = None, +): + if additional_gates is None: + gateset = cirq.CZTargetGateset() + else: + gateset = cirq.CZTargetGateset(additional_gates=additional_gates) + cirq.testing.assert_same_circuits( - cirq.optimize_for_target_gateset(before, gateset=cirq.CZTargetGateset()), expected + cirq.optimize_for_target_gateset(before, gateset=gateset, ignore_failures=False), expected ) @@ -37,7 +47,7 @@ def assert_optimization_not_broken(circuit: cirq.Circuit): circuit, c_new, atol=1e-6 ) c_new = cirq.optimize_for_target_gateset( - circuit, gateset=cirq.CZTargetGateset(allow_partial_czs=True) + circuit, gateset=cirq.CZTargetGateset(allow_partial_czs=True), ignore_failures=False ) cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent( circuit, c_new, atol=1e-6 @@ -57,7 +67,11 @@ def test_convert_to_cz_preserving_moment_structure(): cirq.X(q[2]).with_classical_controls("m"), cirq.CZ(*q[3:]).with_classical_controls("m"), ) - c_new = cirq.optimize_for_target_gateset(c_orig, gateset=cirq.CZTargetGateset()) + # Classically controlled operations are not part of the gateset, so failures should be ignored + # during compilation. + c_new = cirq.optimize_for_target_gateset( + c_orig, gateset=cirq.CZTargetGateset(), ignore_failures=True + ) assert c_orig[-2:] == c_new[-2:] c_orig, c_new = c_orig[:-2], c_new[:-2] @@ -65,7 +79,7 @@ def test_convert_to_cz_preserving_moment_structure(): cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(c_orig, c_new, atol=1e-6) assert all( ( - all_gates_of_type(m, cirq.Gateset(cirq.AnyUnitaryGateFamily(1))) + all_gates_of_type(m, cirq.Gateset(cirq.PhasedXZGate)) or all_gates_of_type(m, cirq.Gateset(cirq.CZ)) ) for m in c_new @@ -77,7 +91,7 @@ def test_convert_to_cz_preserving_moment_structure(): cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(c_orig, c_new, atol=1e-6) assert all( ( - all_gates_of_type(m, cirq.Gateset(cirq.AnyUnitaryGateFamily(1))) + all_gates_of_type(m, cirq.Gateset(cirq.PhasedXZGate)) or all_gates_of_type(m, cirq.Gateset(cirq.CZPowGate)) ) for m in c_new @@ -109,6 +123,7 @@ def test_ignores_czs_separated_by_parameterized(): cirq.Moment(cirq.CZ(a, b)), ] ), + additional_gates=[cirq.ZPowGate], ) @@ -153,7 +168,7 @@ def test_optimizes_single_iswap(): a, b = cirq.LineQubit.range(2) c = cirq.Circuit(cirq.ISWAP(a, b)) assert_optimization_not_broken(c) - c = cirq.optimize_for_target_gateset(c, gateset=cirq.CZTargetGateset()) + c = cirq.optimize_for_target_gateset(c, gateset=cirq.CZTargetGateset(), ignore_failures=False) assert len([1 for op in c.all_operations() if len(op.qubits) == 2]) == 2 @@ -161,7 +176,7 @@ def test_optimizes_tagged_partial_cz(): a, b = cirq.LineQubit.range(2) c = cirq.Circuit((cirq.CZ**0.5)(a, b).with_tags('mytag')) assert_optimization_not_broken(c) - c = cirq.optimize_for_target_gateset(c, gateset=cirq.CZTargetGateset()) + c = cirq.optimize_for_target_gateset(c, gateset=cirq.CZTargetGateset(), ignore_failures=False) assert ( len([1 for op in c.all_operations() if len(op.qubits) == 2]) == 2 ), 'It should take 2 CZ gates to decompose a CZ**0.5 gate' @@ -185,7 +200,9 @@ def test_not_decompose_czs(): ), ) def test_decompose_partial_czs(circuit): - circuit = cirq.optimize_for_target_gateset(circuit, gateset=cirq.CZTargetGateset()) + circuit = cirq.optimize_for_target_gateset( + circuit, gateset=cirq.CZTargetGateset(), ignore_failures=False + ) cz_gates = [ op.gate for op in circuit.all_operations() @@ -201,7 +218,7 @@ def test_not_decompose_partial_czs(): circuit = cirq.Circuit( cirq.CZPowGate(exponent=0.1, global_shift=-0.5)(*cirq.LineQubit.range(2)) ) - cirq.optimize_for_target_gateset(circuit, gateset=cirq.CZTargetGateset()) + cirq.optimize_for_target_gateset(circuit, gateset=cirq.CZTargetGateset(), ignore_failures=False) cz_gates = [ op.gate for op in circuit.all_operations() @@ -240,7 +257,7 @@ def _decompose_(self, qubits): a, b = cirq.LineQubit.range(2) c = cirq.Circuit(OtherXX()(a, b), OtherOtherXX()(a, b)) - c = cirq.optimize_for_target_gateset(c, gateset=cirq.CZTargetGateset()) + c = cirq.optimize_for_target_gateset(c, gateset=cirq.CZTargetGateset(), ignore_failures=False) assert len(c) == 0 @@ -260,7 +277,9 @@ def _decompose_(self, qubits): expected = cirq.Circuit( cirq.X(q0), cirq.Y(q0) ** 0.5, cirq.CZ(q0, q1), cirq.X(q1), cirq.Y(q1) ** 0.5 ) - c_new = cirq.optimize_for_target_gateset(circuit, gateset=cirq.CZTargetGateset()) + c_new = cirq.optimize_for_target_gateset( + circuit, gateset=cirq.CZTargetGateset(), ignore_failures=False + ) cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent( c_new, expected, atol=1e-6 @@ -281,3 +300,24 @@ class UnsupportedDummy(cirq.testing.TwoQubitGate): _ = cirq.optimize_for_target_gateset( circuit, gateset=cirq.CZTargetGateset(), ignore_failures=False ) + + +@pytest.mark.parametrize( + 'gateset', + [ + cirq.CZTargetGateset(), + cirq.CZTargetGateset( + atol=1e-6, + allow_partial_czs=True, + additional_gates=[ + cirq.SQRT_ISWAP, + cirq.XPowGate, + cirq.YPowGate, + cirq.GateFamily(cirq.ZPowGate, tags_to_accept=['test_tag']), + ], + ), + cirq.CZTargetGateset(additional_gates=()), + ], +) +def test_repr(gateset): + cirq.testing.assert_equivalent_repr(gateset) diff --git a/cirq-core/cirq/transformers/target_gatesets/sqrt_iswap_gateset.py b/cirq-core/cirq/transformers/target_gatesets/sqrt_iswap_gateset.py index f6fa3f93c90..5d26665d599 100644 --- a/cirq-core/cirq/transformers/target_gatesets/sqrt_iswap_gateset.py +++ b/cirq-core/cirq/transformers/target_gatesets/sqrt_iswap_gateset.py @@ -14,7 +14,7 @@ """Target gateset used for compiling circuits to √iSWAP + 1-q rotations + measurement gates.""" -from typing import Any, Dict, Optional, TYPE_CHECKING +from typing import Any, Dict, Optional, Sequence, Type, Union, TYPE_CHECKING from cirq import ops, protocols from cirq.protocols.decompose_protocol import DecomposeResult @@ -26,7 +26,22 @@ class SqrtIswapTargetGateset(compilation_target_gateset.TwoQubitCompilationTargetGateset): - """Target gateset containing √iSWAP + single qubit rotations + Measurement gates.""" + """Target gateset accepting √iSWAP + single qubit rotations + measurement gates. + + By default, `cirq.SqrtIswapTargetGateset` will accept and compile unknown gates to + the following universal target gateset: + - `cirq.SQRT_ISWAP` / `cirq.SQRT_ISWAP_INV`: The two qubit entangling gate. + - `cirq.PhasedXZGate`: Single qubit rotations. + - `cirq.MeasurementGate`: Measurements. + - `cirq.GlobalPhaseGate`: Global phase. + + Optionally, users can also specify additional gates / gate families which should + be accepted by this gateset via the `additional_gates` argument. + + When compiling a circuit, any unknown gate, i.e. a gate which is not accepted by + this gateset, will be compiled to the default gateset (i.e. `cirq.SQRT_ISWAP`/ + `cirq.cirq.SQRT_ISWAP_INV`, `cirq.PhasedXZGate`, `cirq.MeasurementGate`). + """ def __init__( self, @@ -34,6 +49,7 @@ def __init__( atol: float = 1e-8, required_sqrt_iswap_count: Optional[int] = None, use_sqrt_iswap_inv: bool = False, + additional_gates: Sequence[Union[Type['cirq.Gate'], 'cirq.Gate', 'cirq.GateFamily']] = (), ): """Initializes `cirq.SqrtIswapTargetGateset` @@ -45,6 +61,8 @@ def __init__( synthesis of the operation requires more. use_sqrt_iswap_inv: If True, `cirq.SQRT_ISWAP_INV` is used as part of the gateset, instead of `cirq.SQRT_ISWAP`. + additional_gates: Sequence of additional gates / gate families which should also + be "accepted" by this gateset. Defaults to `cirq.GlobalPhaseGate`. Raises: ValueError: If `required_sqrt_iswap_count` is specified and is not 0, 1, 2, or 3. @@ -54,10 +72,17 @@ def __init__( super().__init__( ops.SQRT_ISWAP_INV if use_sqrt_iswap_inv else ops.SQRT_ISWAP, ops.MeasurementGate, - ops.AnyUnitaryGateFamily(1), + ops.PhasedXZGate, ops.GlobalPhaseGate, + *additional_gates, name='SqrtIswapInvTargetGateset' if use_sqrt_iswap_inv else 'SqrtIswapTargetGateset', ) + self.additional_gates = tuple( + g if isinstance(g, ops.GateFamily) else ops.GateFamily(gate=g) for g in additional_gates + ) + self._additional_gates_repr_str = ", ".join( + [ops.gateset._gate_str(g, repr) for g in additional_gates] + ) self.atol = atol self.required_sqrt_iswap_count = required_sqrt_iswap_count self.use_sqrt_iswap_inv = use_sqrt_iswap_inv @@ -85,24 +110,36 @@ def __repr__(self) -> str: f'cirq.SqrtIswapTargetGateset(' f'atol={self.atol}, ' f'required_sqrt_iswap_count={self.required_sqrt_iswap_count}, ' - f'use_sqrt_iswap_inv={self.use_sqrt_iswap_inv}' + f'use_sqrt_iswap_inv={self.use_sqrt_iswap_inv}, ' + f'additional_gates=[{self._additional_gates_repr_str}]' f')' ) def _value_equality_values_(self) -> Any: - return (self.atol, self.required_sqrt_iswap_count, self.use_sqrt_iswap_inv) + return ( + self.atol, + self.required_sqrt_iswap_count, + self.use_sqrt_iswap_inv, + frozenset(self.additional_gates), + ) def _json_dict_(self) -> Dict[str, Any]: - return { + d: Dict[str, Any] = { 'atol': self.atol, 'required_sqrt_iswap_count': self.required_sqrt_iswap_count, 'use_sqrt_iswap_inv': self.use_sqrt_iswap_inv, } + if self.additional_gates: + d['additional_gates'] = list(self.additional_gates) + return d @classmethod - def _from_json_dict_(cls, atol, required_sqrt_iswap_count, use_sqrt_iswap_inv, **kwargs): + def _from_json_dict_( + cls, atol, required_sqrt_iswap_count, use_sqrt_iswap_inv, additional_gates=(), **kwargs + ): return cls( atol=atol, required_sqrt_iswap_count=required_sqrt_iswap_count, use_sqrt_iswap_inv=use_sqrt_iswap_inv, + additional_gates=additional_gates, ) diff --git a/cirq-core/cirq/transformers/target_gatesets/sqrt_iswap_gateset_test.py b/cirq-core/cirq/transformers/target_gatesets/sqrt_iswap_gateset_test.py index 9f2af0fd3ba..b2cf7ed4076 100644 --- a/cirq-core/cirq/transformers/target_gatesets/sqrt_iswap_gateset_test.py +++ b/cirq-core/cirq/transformers/target_gatesets/sqrt_iswap_gateset_test.py @@ -29,7 +29,9 @@ def all_gates_of_type(m: cirq.Moment, g: cirq.Gateset): def assert_optimizes(before: cirq.Circuit, expected: cirq.Circuit, **kwargs): cirq.testing.assert_same_circuits( - cirq.optimize_for_target_gateset(before, gateset=cirq.SqrtIswapTargetGateset(**kwargs)), + cirq.optimize_for_target_gateset( + before, gateset=cirq.SqrtIswapTargetGateset(**kwargs), ignore_failures=False + ), expected, ) @@ -40,6 +42,7 @@ def assert_optimization_not_broken( c_new = cirq.optimize_for_target_gateset( circuit, gateset=cirq.SqrtIswapTargetGateset(required_sqrt_iswap_count=required_sqrt_iswap_count), + ignore_failures=False, ) cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent( circuit, c_new, atol=1e-6 @@ -49,6 +52,7 @@ def assert_optimization_not_broken( gateset=cirq.SqrtIswapTargetGateset( use_sqrt_iswap_inv=True, required_sqrt_iswap_count=required_sqrt_iswap_count ), + ignore_failures=False, ) cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent( circuit, c_new, atol=1e-6 @@ -68,8 +72,11 @@ def test_convert_to_sqrt_iswap_preserving_moment_structure(): cirq.X(q[2]).with_classical_controls("m"), cirq.CZ(*q[3:]).with_classical_controls("m"), ) - - c_new = cirq.optimize_for_target_gateset(c_orig, gateset=cirq.SqrtIswapTargetGateset()) + # Classically controlled operations are not part of the gateset, so failures should be ignored + # during compilation. + c_new = cirq.optimize_for_target_gateset( + c_orig, gateset=cirq.SqrtIswapTargetGateset(), ignore_failures=True + ) assert c_orig[-2:] == c_new[-2:] c_orig, c_new = c_orig[:-2], c_new[:-2] @@ -77,7 +84,7 @@ def test_convert_to_sqrt_iswap_preserving_moment_structure(): cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(c_orig, c_new, atol=1e-6) assert all( ( - all_gates_of_type(m, cirq.Gateset(cirq.AnyUnitaryGateFamily(1))) + all_gates_of_type(m, cirq.Gateset(cirq.PhasedXZGate)) or all_gates_of_type(m, cirq.Gateset(cirq.SQRT_ISWAP)) ) for m in c_new @@ -89,7 +96,7 @@ def test_convert_to_sqrt_iswap_preserving_moment_structure(): cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(c_orig, c_new, atol=1e-6) assert all( ( - all_gates_of_type(m, cirq.Gateset(cirq.AnyUnitaryGateFamily(1))) + all_gates_of_type(m, cirq.Gateset(cirq.PhasedXZGate)) or all_gates_of_type(m, cirq.Gateset(cirq.SQRT_ISWAP_INV)) ) for m in c_new @@ -111,7 +118,12 @@ def test_two_qubit_gates_with_symbols(gate: cirq.Gate, use_sqrt_iswap_inv: bool) c_orig = cirq.Circuit(gate(*cirq.LineQubit.range(2))) c_new = cirq.optimize_for_target_gateset( - c_orig, gateset=cirq.SqrtIswapTargetGateset(use_sqrt_iswap_inv=use_sqrt_iswap_inv) + c_orig, + gateset=cirq.SqrtIswapTargetGateset( + use_sqrt_iswap_inv=use_sqrt_iswap_inv, + additional_gates=[cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate], + ), + ignore_failures=False, ) # Check that `c_new` only contains sqrt iswap as the 2q entangling gate. @@ -145,6 +157,7 @@ def test_sqrt_iswap_gateset_eq(): eq.add_equality_group( cirq.SqrtIswapTargetGateset(atol=1e-6, required_sqrt_iswap_count=3, use_sqrt_iswap_inv=True) ) + eq.add_equality_group(cirq.SqrtIswapTargetGateset(additional_gates=[cirq.XPowGate])) @pytest.mark.parametrize( @@ -152,8 +165,17 @@ def test_sqrt_iswap_gateset_eq(): [ cirq.SqrtIswapTargetGateset(), cirq.SqrtIswapTargetGateset( - atol=1e-6, required_sqrt_iswap_count=2, use_sqrt_iswap_inv=True + atol=1e-6, + required_sqrt_iswap_count=2, + use_sqrt_iswap_inv=True, + additional_gates=[ + cirq.CZ, + cirq.XPowGate, + cirq.YPowGate, + cirq.GateFamily(cirq.ZPowGate, tags_to_accept=['test_tag']), + ], ), + cirq.SqrtIswapTargetGateset(additional_gates=()), ], ) def test_sqrt_iswap_gateset_repr(gateset): @@ -282,7 +304,9 @@ def test_optimizes_single_iswap(): a, b = cirq.LineQubit.range(2) c = cirq.Circuit(cirq.ISWAP(a, b)) assert_optimization_not_broken(c) - c = cirq.optimize_for_target_gateset(c, gateset=cirq.SqrtIswapTargetGateset()) + c = cirq.optimize_for_target_gateset( + c, gateset=cirq.SqrtIswapTargetGateset(), ignore_failures=False + ) assert len([1 for op in c.all_operations() if len(op.qubits) == 2]) == 2 @@ -290,7 +314,9 @@ def test_optimizes_single_inv_sqrt_iswap(): a, b = cirq.LineQubit.range(2) c = cirq.Circuit(cirq.SQRT_ISWAP_INV(a, b)) assert_optimization_not_broken(c) - c = cirq.optimize_for_target_gateset(c, gateset=cirq.SqrtIswapTargetGateset()) + c = cirq.optimize_for_target_gateset( + c, gateset=cirq.SqrtIswapTargetGateset(), ignore_failures=False + ) assert len([1 for op in c.all_operations() if len(op.qubits) == 2]) == 1 @@ -299,7 +325,7 @@ def test_optimizes_single_iswap_require0(): c = cirq.Circuit(cirq.CNOT(a, b), cirq.CNOT(a, b)) # Minimum 0 sqrt-iSWAP assert_optimization_not_broken(c, required_sqrt_iswap_count=0) c = cirq.optimize_for_target_gateset( - c, gateset=cirq.SqrtIswapTargetGateset(required_sqrt_iswap_count=0) + c, gateset=cirq.SqrtIswapTargetGateset(required_sqrt_iswap_count=0), ignore_failures=False ) assert len([1 for op in c.all_operations() if len(op.qubits) == 2]) == 0 @@ -309,7 +335,9 @@ def test_optimizes_single_iswap_require0_raises(): c = cirq.Circuit(cirq.CNOT(a, b)) # Minimum 2 sqrt-iSWAP with pytest.raises(ValueError, match='cannot be decomposed into exactly 0 sqrt-iSWAP gates'): _ = cirq.optimize_for_target_gateset( - c, gateset=cirq.SqrtIswapTargetGateset(required_sqrt_iswap_count=0) + c, + gateset=cirq.SqrtIswapTargetGateset(required_sqrt_iswap_count=0), + ignore_failures=False, ) @@ -318,7 +346,7 @@ def test_optimizes_single_iswap_require1(): c = cirq.Circuit(cirq.SQRT_ISWAP_INV(a, b)) # Minimum 1 sqrt-iSWAP assert_optimization_not_broken(c, required_sqrt_iswap_count=1) c = cirq.optimize_for_target_gateset( - c, gateset=cirq.SqrtIswapTargetGateset(required_sqrt_iswap_count=1) + c, gateset=cirq.SqrtIswapTargetGateset(required_sqrt_iswap_count=1), ignore_failures=False ) assert len([1 for op in c.all_operations() if len(op.qubits) == 2]) == 1 @@ -328,7 +356,9 @@ def test_optimizes_single_iswap_require1_raises(): c = cirq.Circuit(cirq.CNOT(a, b)) # Minimum 2 sqrt-iSWAP with pytest.raises(ValueError, match='cannot be decomposed into exactly 1 sqrt-iSWAP gates'): c = cirq.optimize_for_target_gateset( - c, gateset=cirq.SqrtIswapTargetGateset(required_sqrt_iswap_count=1) + c, + gateset=cirq.SqrtIswapTargetGateset(required_sqrt_iswap_count=1), + ignore_failures=False, ) @@ -337,7 +367,7 @@ def test_optimizes_single_iswap_require2(): c = cirq.Circuit(cirq.SQRT_ISWAP_INV(a, b)) # Minimum 1 sqrt-iSWAP but 2 possible assert_optimization_not_broken(c, required_sqrt_iswap_count=2) c = cirq.optimize_for_target_gateset( - c, gateset=cirq.SqrtIswapTargetGateset(required_sqrt_iswap_count=2) + c, gateset=cirq.SqrtIswapTargetGateset(required_sqrt_iswap_count=2), ignore_failures=False ) assert len([1 for op in c.all_operations() if len(op.qubits) == 2]) == 2 @@ -347,7 +377,9 @@ def test_optimizes_single_iswap_require2_raises(): c = cirq.Circuit(cirq.SWAP(a, b)) # Minimum 3 sqrt-iSWAP with pytest.raises(ValueError, match='cannot be decomposed into exactly 2 sqrt-iSWAP gates'): c = cirq.optimize_for_target_gateset( - c, gateset=cirq.SqrtIswapTargetGateset(required_sqrt_iswap_count=2) + c, + gateset=cirq.SqrtIswapTargetGateset(required_sqrt_iswap_count=2), + ignore_failures=False, ) @@ -356,7 +388,7 @@ def test_optimizes_single_iswap_require3(): c = cirq.Circuit(cirq.ISWAP(a, b)) # Minimum 2 sqrt-iSWAP but 3 possible assert_optimization_not_broken(c, required_sqrt_iswap_count=3) c = cirq.optimize_for_target_gateset( - c, gateset=cirq.SqrtIswapTargetGateset(required_sqrt_iswap_count=3) + c, gateset=cirq.SqrtIswapTargetGateset(required_sqrt_iswap_count=3), ignore_failures=False ) assert len([1 for op in c.all_operations() if len(op.qubits) == 2]) == 3 @@ -366,6 +398,6 @@ def test_optimizes_single_inv_sqrt_iswap_require3(): c = cirq.Circuit(cirq.SQRT_ISWAP_INV(a, b)) assert_optimization_not_broken(c, required_sqrt_iswap_count=3) c = cirq.optimize_for_target_gateset( - c, gateset=cirq.SqrtIswapTargetGateset(required_sqrt_iswap_count=3) + c, gateset=cirq.SqrtIswapTargetGateset(required_sqrt_iswap_count=3), ignore_failures=False ) assert len([1 for op in c.all_operations() if len(op.qubits) == 2]) == 3