Skip to content

Commit

Permalink
cirq-core target gatesets: accept additional gates to keep untouched. (
Browse files Browse the repository at this point in the history
…#5445)

Builds on top of #5429

The internal gate representation for `additional_gates` is updated to match `cirq.Gateset`:
* Equality check uses GateFamily representation. Otherwise different representations of the gate will not be considered equal.
* JSON uses GateFamily representation.
* repr uses the representation passed in via the constructor.

`assert_optimizes` in `cz_gateset_test.py` is updated to take in an optional `additional_gates` instead, to exercise CZTargetGateset constructor's defaulting logic.

No tests are added since `additional_gates` need to be set in existing tests after `ignore_errors` is set to False.

@tanujkhattar
  • Loading branch information
verult authored Jun 17, 2022
1 parent c58dd4d commit 1e5d85e
Show file tree
Hide file tree
Showing 9 changed files with 314 additions and 53 deletions.
7 changes: 4 additions & 3 deletions cirq-core/cirq/contrib/paulistring/optimize_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,15 @@ def test_optimize():

cirq.testing.assert_allclose_up_to_global_phase(c_orig.unitary(), c_opt.unitary(), atol=1e-6)

# TODO(#5546) Fix '[Z]^1' (should be 'Z')
cirq.testing.assert_has_diagram(
c_opt,
"""
0: ───X^0.5────────────@────────────────────────────────────────
0: ───X^0.5────────────@──────────────────────────────────────────────
1: ───@───────X^-0.5───@───@────────────────@───Z^-0.5──────────
1: ───@───────X^-0.5───@───@────────────────@───Z^-0.5────────────────
│ │ │
2: ───@────────────────────@───[X]^(-7/8)───@───[X]^-0.25───Z───
2: ───@────────────────────@───[X]^(-7/8)───@───[X]^-0.25───[Z]^(1)───
""",
)

Expand Down
38 changes: 38 additions & 0 deletions cirq-core/cirq/protocols/json_test_data/CZTargetGateset.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,43 @@
"cirq_type": "CZTargetGateset",
"atol": 1e-08,
"allow_partial_czs": true
},
{
"cirq_type": "CZTargetGateset",
"atol": 1e-06,
"allow_partial_czs": true,
"additional_gates": [
{
"cirq_type": "GateFamily",
"gate": {
"cirq_type": "ISwapPowGate",
"exponent": 0.5,
"global_shift": 0.0
},
"name": "Instance GateFamily: ISWAP**0.5",
"description": "Accepts `cirq.Gate` instances `g` s.t. `g == ISWAP**0.5`",
"ignore_global_phase": true,
"tags_to_accept": [],
"tags_to_ignore": []
},
{
"cirq_type": "GateFamily",
"gate": "XPowGate",
"name": "Type GateFamily: cirq.ops.common_gates.XPowGate",
"description": "Accepts `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.XPowGate)`",
"ignore_global_phase": true,
"tags_to_accept": [],
"tags_to_ignore": []
},
{
"cirq_type": "GateFamily",
"gate": "ZPowGate",
"name": "Type GateFamily: cirq.ops.common_gates.ZPowGate",
"description": "Accepts `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.ZPowGate)`",
"ignore_global_phase": true,
"tags_to_accept": [],
"tags_to_ignore": []
}
]
}
]
18 changes: 16 additions & 2 deletions cirq-core/cirq/protocols/json_test_data/CZTargetGateset.repr
Original file line number Diff line number Diff line change
@@ -1,4 +1,18 @@
[
cirq.CZTargetGateset(atol=1e-06, allow_partial_czs=False),
cirq.CZTargetGateset(atol=1e-08, allow_partial_czs=True),
cirq.CZTargetGateset(atol=1e-06, allow_partial_czs=False, additional_gates=[]),
cirq.CZTargetGateset(atol=1e-08, allow_partial_czs=True, additional_gates=[]),
cirq.CZTargetGateset(
atol=1e-06,
allow_partial_czs=True,
additional_gates=[
(cirq.ISWAP**0.5),
cirq.ops.common_gates.XPowGate,
cirq.GateFamily(
gate=cirq.ops.common_gates.ZPowGate,
ignore_global_phase=True,
tags_to_accept=frozenset(),
tags_to_ignore=frozenset(),
),
],
),
]
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,44 @@
"atol": 1e-06,
"required_sqrt_iswap_count": 2,
"use_sqrt_iswap_inv": true
},
{
"cirq_type": "SqrtIswapTargetGateset",
"atol": 1e-08,
"required_sqrt_iswap_count": null,
"use_sqrt_iswap_inv": false,
"additional_gates": [
{
"cirq_type": "GateFamily",
"gate": {
"cirq_type": "CZPowGate",
"exponent": 1.0,
"global_shift": 0.0
},
"name": "Instance GateFamily: CZ",
"description": "Accepts `cirq.Gate` instances `g` s.t. `g == CZ`",
"ignore_global_phase": true,
"tags_to_accept": [],
"tags_to_ignore": []
},
{
"cirq_type": "GateFamily",
"gate": "XPowGate",
"name": "Type GateFamily: cirq.ops.common_gates.XPowGate",
"description": "Accepts `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.XPowGate)`",
"ignore_global_phase": true,
"tags_to_accept": [],
"tags_to_ignore": []
},
{
"cirq_type": "GateFamily",
"gate": "ZPowGate",
"name": "Type GateFamily: cirq.ops.common_gates.ZPowGate",
"description": "Accepts `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.ZPowGate)`",
"ignore_global_phase": true,
"tags_to_accept": [],
"tags_to_ignore": []
}
]
}
]
Original file line number Diff line number Diff line change
@@ -1,7 +1,26 @@
[
cirq.SqrtIswapTargetGateset(
atol=1e-08, required_sqrt_iswap_count=None, use_sqrt_iswap_inv=False
atol=1e-08, required_sqrt_iswap_count=None, use_sqrt_iswap_inv=False, additional_gates=[]
),
cirq.SqrtIswapTargetGateset(
atol=1e-08, required_sqrt_iswap_count=1, use_sqrt_iswap_inv=False, additional_gates=[]
),
cirq.SqrtIswapTargetGateset(
atol=1e-06, required_sqrt_iswap_count=2, use_sqrt_iswap_inv=True, additional_gates=[]
),
cirq.SqrtIswapTargetGateset(
atol=1e-08,
required_sqrt_iswap_count=None,
use_sqrt_iswap_inv=False,
additional_gates=[
cirq.CZ,
cirq.ops.common_gates.XPowGate,
cirq.GateFamily(
gate=cirq.ops.common_gates.ZPowGate,
ignore_global_phase=True,
tags_to_accept=frozenset(),
tags_to_ignore=frozenset(),
),
],
),
cirq.SqrtIswapTargetGateset(atol=1e-08, required_sqrt_iswap_count=1, use_sqrt_iswap_inv=False),
cirq.SqrtIswapTargetGateset(atol=1e-06, required_sqrt_iswap_count=2, use_sqrt_iswap_inv=True),
]
59 changes: 50 additions & 9 deletions cirq-core/cirq/transformers/target_gatesets/cz_gateset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Target gateset used for compiling circuits to CZ + 1-q rotations + measurement gates."""

from typing import Any, Dict, TYPE_CHECKING
from typing import Any, Dict, Sequence, Type, Union, TYPE_CHECKING

from cirq import ops, protocols
from cirq.transformers.analytical_decompositions import two_qubit_to_cz
Expand All @@ -25,23 +25,53 @@


class CZTargetGateset(compilation_target_gateset.TwoQubitCompilationTargetGateset):
"""Target gateset containing CZ + single qubit rotations + Measurement gates."""
"""Target gateset accepting CZ + single qubit rotations + measurement gates.
def __init__(self, *, atol: float = 1e-8, allow_partial_czs: bool = False) -> None:
By default, `cirq.CZTargetGateset` will accept and compile unknown gates to
the following universal target gateset:
- `cirq.CZ` / `cirq.CZPowGate`: The two qubit entangling gate.
- `cirq.PhasedXZGate`: Single qubit rotations.
- `cirq.MeasurementGate`: Measurements.
- `cirq.GlobalPhaseGate`: Global phase.
Optionally, users can also specify additional gates / gate families which should
be accepted by this gateset via the `additional_gates` argument.
When compiling a circuit, any unknown gate, i.e. a gate which is not accepted by
this gateset, will be compiled to the default gateset (i.e. `cirq.CZ`/`cirq.CZPowGate`,
`cirq.PhasedXZGate`, `cirq.MeasurementGate`).
"""

def __init__(
self,
*,
atol: float = 1e-8,
allow_partial_czs: bool = False,
additional_gates: Sequence[Union[Type['cirq.Gate'], 'cirq.Gate', 'cirq.GateFamily']] = (),
) -> None:
"""Initializes CZTargetGateset
Args:
atol: A limit on the amount of absolute error introduced by the decomposition.
allow_partial_czs: If set, all powers of the form `cirq.CZ**t`, and not just
`cirq.CZ`, are part of this gateset.
additional_gates: Sequence of additional gates / gate families which should also
be "accepted" by this gateset. Defaults to `cirq.GlobalPhaseGate`.
"""
super().__init__(
ops.CZPowGate if allow_partial_czs else ops.CZ,
ops.MeasurementGate,
ops.AnyUnitaryGateFamily(1),
ops.PhasedXZGate,
ops.GlobalPhaseGate,
*additional_gates,
name='CZPowTargetGateset' if allow_partial_czs else 'CZTargetGateset',
)
self.additional_gates = tuple(
g if isinstance(g, ops.GateFamily) else ops.GateFamily(gate=g) for g in additional_gates
)
self._additional_gates_repr_str = ", ".join(
[ops.gateset._gate_str(g, repr) for g in additional_gates]
)
self.atol = atol
self.allow_partial_czs = allow_partial_czs

Expand All @@ -57,14 +87,25 @@ def _decompose_two_qubit_operation(self, op: 'cirq.Operation', _) -> 'cirq.OP_TR
)

def __repr__(self) -> str:
return f'cirq.CZTargetGateset(atol={self.atol}, allow_partial_czs={self.allow_partial_czs})'
return (
f'cirq.CZTargetGateset('
f'atol={self.atol}, '
f'allow_partial_czs={self.allow_partial_czs}, '
f'additional_gates=[{self._additional_gates_repr_str}]'
f')'
)

def _value_equality_values_(self) -> Any:
return self.atol, self.allow_partial_czs
return self.atol, self.allow_partial_czs, frozenset(self.additional_gates)

def _json_dict_(self) -> Dict[str, Any]:
return {'atol': self.atol, 'allow_partial_czs': self.allow_partial_czs}
d: Dict[str, Any] = {'atol': self.atol, 'allow_partial_czs': self.allow_partial_czs}
if self.additional_gates:
d['additional_gates'] = list(self.additional_gates)
return d

@classmethod
def _from_json_dict_(cls, atol, allow_partial_czs, **kwargs):
return cls(atol=atol, allow_partial_czs=allow_partial_czs)
def _from_json_dict_(cls, atol, allow_partial_czs, additional_gates=(), **kwargs):
return cls(
atol=atol, allow_partial_czs=allow_partial_czs, additional_gates=additional_gates
)
Loading

0 comments on commit 1e5d85e

Please sign in to comment.