Skip to content

Commit

Permalink
Deprecate Gateset.accept_global_phase_op (#5239)
Browse files Browse the repository at this point in the history
Deprecates Gateset.accept_global_phase_op

**Breaking Change:** Changes the default value of `Gateset.accept_global_phase_op` from `True` to `False`. I can't think of any way to remove this parameter without eventually needing this breaking change. Currently all gatesets that are created allow global phase gates if they don't specify `accept_global_phase_op=False` explicitly. But the end goal is only to allow global phase gates if they're included in the `gates` list. So at some point in the transition the default behavior needs to break, and I can't think of a way of doing that via deprecation. Therefore I think we may as well do it now via this breaking change.

Note that even though it's breaking, it isn't breaking in a bad way. Users who are adding global phase gates to things that suddenly don't accept them will just see an error that the gate is not in the gateset, and then go add it. It's much safer than breaking in the other direction in which we silently start allowing new gate types.

Closes #4741

@tanujkhattar
  • Loading branch information
daxfohl authored Apr 18, 2022
1 parent 0f995ee commit aa303dc
Show file tree
Hide file tree
Showing 25 changed files with 215 additions and 77 deletions.
1 change: 0 additions & 1 deletion cirq-core/cirq/ion/ion_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def get_ion_gateset() -> ops.Gateset:
ops.ZPowGate,
ops.PhasedXPowGate,
unroll_circuit_op=False,
accept_global_phase_op=False,
)


Expand Down
6 changes: 3 additions & 3 deletions cirq-core/cirq/ion/ion_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def num_qubits(self):


def test_can_add_operation_into_moment_device_deprecated():
with cirq.testing.assert_deprecated('can_add_operation_into_moment', deadline='v0.15', count=5):
with cirq.testing.assert_deprecated('can_add_operation_into_moment', deadline='v0.15', count=6):
d = ion_device(3)
q0 = cirq.LineQubit(0)
q1 = cirq.LineQubit(1)
Expand Down Expand Up @@ -218,10 +218,10 @@ def test_at():


def test_qubit_set_deprecated():
with cirq.testing.assert_deprecated('qubit_set', deadline='v0.15'):
with cirq.testing.assert_deprecated('qubit_set', deadline='v0.15', count=2):
assert ion_device(3).qubit_set() == frozenset(cirq.LineQubit.range(3))


def test_qid_pairs_deprecated():
with cirq.testing.assert_deprecated('device.metadata', deadline='v0.15', count=1):
with cirq.testing.assert_deprecated('device.metadata', deadline='v0.15', count=2):
assert len(ion_device(10).qid_pairs()) == 45
3 changes: 0 additions & 3 deletions cirq-core/cirq/neutral_atoms/neutral_atom_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def neutral_atom_gateset(max_parallel_z=None, max_parallel_xy=None):
ops.MeasurementGate,
ops.IdentityGate,
unroll_circuit_op=False,
accept_global_phase_op=False,
)


Expand Down Expand Up @@ -100,15 +99,13 @@ def __init__(
ops.ParallelGateFamily(ops.YPowGate),
ops.ParallelGateFamily(ops.PhasedXPowGate),
unroll_circuit_op=False,
accept_global_phase_op=False,
)
self.controlled_gateset = ops.Gateset(
ops.AnyIntegerPowerGateFamily(ops.CNotPowGate),
ops.AnyIntegerPowerGateFamily(ops.CCNotPowGate),
ops.AnyIntegerPowerGateFamily(ops.CZPowGate),
ops.AnyIntegerPowerGateFamily(ops.CCZPowGate),
unroll_circuit_op=False,
accept_global_phase_op=False,
)
self.gateset = neutral_atom_gateset(max_parallel_z, max_parallel_xy)
for q in qubits:
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/neutral_atoms/neutral_atom_devices_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def test_validate_moment_errors():


def test_can_add_operation_into_moment_coverage_deprecated():
with cirq.testing.assert_deprecated('can_add_operation_into_moment', deadline='v0.15', count=3):
with cirq.testing.assert_deprecated('can_add_operation_into_moment', deadline='v0.15', count=4):
d = square_device(2, 2)
q00 = cirq.GridQubit(0, 0)
q01 = cirq.GridQubit(0, 1)
Expand Down Expand Up @@ -298,5 +298,5 @@ def test_repr_pretty():


def test_qubit_set_deprecated():
with cirq.testing.assert_deprecated('qubit_set', deadline='v0.15'):
with cirq.testing.assert_deprecated('qubit_set', deadline='v0.15', count=2):
assert square_device(2, 2).qubit_set() == frozenset(cirq.GridQubit.square(2, 0, 0))
99 changes: 70 additions & 29 deletions cirq-core/cirq/ops/gateset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@

"""Functionality for grouping and validating Cirq Gates"""

import warnings
from typing import Any, Callable, cast, Dict, FrozenSet, List, Optional, Type, TYPE_CHECKING, Union

from cirq import _compat, protocols, value
from cirq.ops import global_phase_op, op_tree, raw_types
from cirq import protocols, value

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -201,12 +203,20 @@ class Gateset:
validation purposes.
"""

@_compat.deprecated_parameter(
deadline='v0.16',
fix='To accept global phase gates, add cirq.GlobalPhaseGate to the list of *gates passed '
'to the constructor. By default, global phase gates will not be accepted by the '
'gateset',
parameter_desc='accept_global_phase_op',
match=lambda args, kwargs: 'accept_global_phase_op' in kwargs,
)
def __init__(
self,
*gates: Union[Type[raw_types.Gate], raw_types.Gate, GateFamily],
name: Optional[str] = None,
unroll_circuit_op: bool = True,
accept_global_phase_op: bool = True,
accept_global_phase_op: Optional[bool] = None,
) -> None:
"""Init Gateset.
Expand All @@ -225,17 +235,36 @@ def __init__(
name: (Optional) Name for the Gateset. Useful for description.
unroll_circuit_op: If True, `cirq.CircuitOperation` is recursively
validated by validating the underlying `cirq.Circuit`.
accept_global_phase_op: If True, `cirq.GlobalPhaseOperation` is accepted.
accept_global_phase_op: If True, a `GateFamily` accepting
`cirq.GlobalPhaseGate` will be included. If None,
`cirq.GlobalPhaseGate` will not modify the input `*gates`.
If False, `cirq.GlobalPhaseGate` will be removed from the
gates. This parameter defaults to None (a breaking change from
v0.14.1) and will be removed in v0.16.
"""
self._name = name
self._unroll_circuit_op = unroll_circuit_op
self._accept_global_phase_op = accept_global_phase_op
if accept_global_phase_op:
gates = gates + (global_phase_op.GlobalPhaseGate,)
self._instance_gate_families: Dict[raw_types.Gate, GateFamily] = {}
self._type_gate_families: Dict[Type[raw_types.Gate], GateFamily] = {}
self._gates_repr_str = ", ".join([_gate_str(g, repr) for g in gates])
unique_gate_list: List[GateFamily] = list(
dict.fromkeys(g if isinstance(g, GateFamily) else GateFamily(gate=g) for g in gates)
)
if accept_global_phase_op is False:
unique_gate_list = [
g for g in unique_gate_list if g.gate is not global_phase_op.GlobalPhaseGate
]
elif accept_global_phase_op is None:
if not any(g.gate is global_phase_op.GlobalPhaseGate for g in unique_gate_list):
warnings.warn(
'v0.14.1 is the last release `cirq.GlobalPhaseGate` is included by default. If'
' you were relying on this behavior, you can include a `cirq.GlobalPhaseGate`'
' in your `*gates`. If not, then you can ignore this warning. It will be'
' removed in v0.16'
)

for g in unique_gate_list:
if type(g) == GateFamily:
if isinstance(g.gate, raw_types.Gate):
Expand All @@ -253,6 +282,12 @@ def name(self) -> Optional[str]:
def gates(self) -> FrozenSet[GateFamily]:
return self._gates

@_compat.deprecated_parameter(
deadline='v0.16',
fix='Add a global phase gate to the Gateset',
parameter_desc='accept_global_phase_op',
match=lambda args, kwargs: 'accept_global_phase_op' in kwargs,
)
def with_params(
self,
*,
Expand All @@ -268,7 +303,12 @@ def with_params(
name: New name for the Gateset.
unroll_circuit_op: If True, new Gateset will recursively validate
`cirq.CircuitOperation` by validating the underlying `cirq.Circuit`.
accept_global_phase_op: If True, new Gateset will accept `cirq.GlobalPhaseOperation`.
accept_global_phase_op: If True, a `GateFamily` accepting
`cirq.GlobalPhaseGate` will be included. If None,
`cirq.GlobalPhaseGate` will not modify the input `*gates`.
If False, `cirq.GlobalPhaseGate` will be removed from the
gates. This parameter defaults to None (a breaking change from
v0.14.1) and will be removed in v0.16.
Returns:
`self` if all new values are None or identical to the values of current Gateset.
Expand All @@ -280,19 +320,23 @@ def val_if_none(var: Any, val: Any) -> Any:

name = val_if_none(name, self._name)
unroll_circuit_op = val_if_none(unroll_circuit_op, self._unroll_circuit_op)
accept_global_phase_op = val_if_none(accept_global_phase_op, self._accept_global_phase_op)
global_phase_family = GateFamily(gate=global_phase_op.GlobalPhaseGate)
if (
name == self._name
and unroll_circuit_op == self._unroll_circuit_op
and accept_global_phase_op == self._accept_global_phase_op
and (
accept_global_phase_op is True
and global_phase_family in self.gates
or accept_global_phase_op is False
and not any(g.gate is global_phase_op.GlobalPhaseGate for g in self.gates)
or accept_global_phase_op is None
)
):
return self
return Gateset(
*self.gates,
name=name,
unroll_circuit_op=cast(bool, unroll_circuit_op),
accept_global_phase_op=cast(bool, accept_global_phase_op),
)
gates = self.gates
if accept_global_phase_op:
gates = gates.union({global_phase_family})
return Gateset(*gates, name=name, unroll_circuit_op=cast(bool, unroll_circuit_op))

def __contains__(self, item: Union[raw_types.Gate, raw_types.Operation]) -> bool:
"""Check for containment of a given Gate/Operation in this Gateset.
Expand Down Expand Up @@ -326,9 +370,6 @@ def __contains__(self, item: Union[raw_types.Gate, raw_types.Operation]) -> bool
g = item if isinstance(item, raw_types.Gate) else item.gate
assert g is not None, f'`item`: {item} must be a gate or have a valid `item.gate`'

if isinstance(g, global_phase_op.GlobalPhaseGate):
return self._accept_global_phase_op

if g in self._instance_gate_families:
assert item in self._instance_gate_families[g], (
f"{item} instance matches {self._instance_gate_families[g]} but "
Expand Down Expand Up @@ -394,16 +435,15 @@ def _validate_operation(self, op: raw_types.Operation) -> bool:
return False

def _value_equality_values_(self) -> Any:
return (self.gates, self.name, self._unroll_circuit_op, self._accept_global_phase_op)
return (self.gates, self.name, self._unroll_circuit_op)

def __repr__(self) -> str:
name_str = f'name = "{self.name}", ' if self.name is not None else ''
return (
f'cirq.Gateset('
f'{self._gates_repr_str}, '
f'{name_str}'
f'unroll_circuit_op = {self._unroll_circuit_op},'
f'accept_global_phase_op = {self._accept_global_phase_op})'
f'unroll_circuit_op = {self._unroll_circuit_op})'
)

def __str__(self) -> str:
Expand All @@ -417,16 +457,17 @@ def _json_dict_(self) -> Dict[str, Any]:
'gates': self._unique_gate_list,
'name': self.name,
'unroll_circuit_op': self._unroll_circuit_op,
'accept_global_phase_op': self._accept_global_phase_op,
}

@classmethod
def _from_json_dict_(
cls, gates, name, unroll_circuit_op, accept_global_phase_op, **kwargs
) -> 'Gateset':
return cls(
*gates,
name=name,
unroll_circuit_op=unroll_circuit_op,
accept_global_phase_op=accept_global_phase_op,
)
def _from_json_dict_(cls, gates, name, unroll_circuit_op, **kwargs) -> 'Gateset':
if 'accept_global_phase_op' in kwargs:
accept_global_phase_op = kwargs['accept_global_phase_op']
global_phase_family = GateFamily(gate=global_phase_op.GlobalPhaseGate)
if accept_global_phase_op is True:
gates.append(global_phase_family)
elif accept_global_phase_op is False:
gates = [
family for family in gates if family.gate is not global_phase_op.GlobalPhaseGate
]
return cls(*gates, name=name, unroll_circuit_op=unroll_circuit_op)
60 changes: 35 additions & 25 deletions cirq-core/cirq/ops/gateset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,19 +257,21 @@ def assert_validate_and_contains_consistent(gateset, op_tree, result):
assert gateset.validate(item) is result

op_tree = [*get_ops(use_circuit_op, use_global_phase)]
assert_validate_and_contains_consistent(
gateset.with_params(
unroll_circuit_op=use_circuit_op, accept_global_phase_op=use_global_phase
),
op_tree,
True,
)
if use_circuit_op or use_global_phase:
with cirq.testing.assert_deprecated('global phase', deadline='v0.16', count=None):
assert_validate_and_contains_consistent(
gateset.with_params(unroll_circuit_op=False, accept_global_phase_op=False),
gateset.with_params(
unroll_circuit_op=use_circuit_op, accept_global_phase_op=use_global_phase
),
op_tree,
False,
True,
)
if use_circuit_op or use_global_phase:
with cirq.testing.assert_deprecated('global phase', deadline='v0.16', count=2):
assert_validate_and_contains_consistent(
gateset.with_params(unroll_circuit_op=False, accept_global_phase_op=False),
op_tree,
False,
)


def test_gateset_validate_circuit_op_negative_reps():
Expand All @@ -281,31 +283,39 @@ def test_gateset_validate_circuit_op_negative_reps():

def test_with_params():
assert gateset.with_params() is gateset
assert (
gateset.with_params(
name=gateset.name,
unroll_circuit_op=gateset._unroll_circuit_op,
accept_global_phase_op=gateset._accept_global_phase_op,
with cirq.testing.assert_deprecated('global phase', deadline='v0.16'):
assert (
gateset.with_params(
name=gateset.name,
unroll_circuit_op=gateset._unroll_circuit_op,
accept_global_phase_op=None,
)
is gateset
)
with cirq.testing.assert_deprecated('global phase', deadline='v0.16', count=2):
gateset_with_params = gateset.with_params(
name='new name', unroll_circuit_op=False, accept_global_phase_op=False
)
is gateset
)
gateset_with_params = gateset.with_params(
name='new name', unroll_circuit_op=False, accept_global_phase_op=False
)
assert gateset_with_params.name == 'new name'
assert gateset_with_params._unroll_circuit_op is False
assert gateset_with_params._accept_global_phase_op is False


def test_gateset_eq():
eq = cirq.testing.EqualsTester()
eq.add_equality_group(cirq.Gateset(CustomX))
eq.add_equality_group(cirq.Gateset(CustomX**3))
eq.add_equality_group(cirq.Gateset(CustomX, name='Custom Gateset'))
with cirq.testing.assert_deprecated('global phase', deadline='v0.16'):
eq.add_equality_group(
cirq.Gateset(CustomX, name='Custom Gateset'),
cirq.Gateset(
CustomX, cirq.GlobalPhaseGate, name='Custom Gateset', accept_global_phase_op=False
),
)
eq.add_equality_group(cirq.Gateset(CustomX, name='Custom Gateset', unroll_circuit_op=False))
eq.add_equality_group(
cirq.Gateset(CustomX, name='Custom Gateset', accept_global_phase_op=False)
)
with cirq.testing.assert_deprecated('global phase', deadline='v0.16'):
eq.add_equality_group(
cirq.Gateset(CustomX, name='Custom Gateset', accept_global_phase_op=True)
)
eq.add_equality_group(
cirq.Gateset(
cirq.GateFamily(CustomX, name='custom_name', description='custom_description'),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(self, ignore_failures: bool = False, allow_partial_czs: bool = Fals
ops.CZPowGate if allow_partial_czs else ops.CZ,
ops.MeasurementGate,
ops.AnyUnitaryGateFamily(1),
ops.GlobalPhaseGate,
)

def _decompose_two_qubit_unitaries(self, op: ops.Operation) -> ops.OP_TREE:
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/optimizers/merge_interactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,8 @@ def __init__(
self.allow_partial_czs = allow_partial_czs
self.gateset = ops.Gateset(
ops.CZPowGate if allow_partial_czs else ops.CZ,
ops.GlobalPhaseGate,
unroll_circuit_op=False,
accept_global_phase_op=True,
)

def _may_keep_old_op(self, old_op: 'cirq.Operation') -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ def __init__(
self.use_sqrt_iswap_inv = use_sqrt_iswap_inv
self.gateset = ops.Gateset(
ops.SQRT_ISWAP_INV if use_sqrt_iswap_inv else ops.SQRT_ISWAP,
ops.GlobalPhaseGate,
unroll_circuit_op=False,
accept_global_phase_op=True,
)

def _may_keep_old_op(self, old_op: 'cirq.Operation') -> bool:
Expand Down
6 changes: 2 additions & 4 deletions cirq-core/cirq/protocols/json_test_data/Gateset.json
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@
}
],
"name": null,
"unroll_circuit_op": true,
"accept_global_phase_op": true
"unroll_circuit_op": true
},
{
"cirq_type": "Gateset",
Expand Down Expand Up @@ -56,7 +55,6 @@
}
],
"name": "Custom Name",
"unroll_circuit_op": false,
"accept_global_phase_op": false
"unroll_circuit_op": false
}
]
Loading

0 comments on commit aa303dc

Please sign in to comment.