Skip to content

Commit

Permalink
Support serialization of GateFamilies (#4715)
Browse files Browse the repository at this point in the history
Exactly what it says: this PR uses the new type-serialization behavior to allow GateFamilies to be serialized.

Additional work is still required for Gateset serialization - that behavior is not included in this PR.
95-martin-orion authored Dec 1, 2021

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 7259925 commit 9053a27
Showing 18 changed files with 240 additions and 5 deletions.
4 changes: 4 additions & 0 deletions cirq-core/cirq/json_resolver_cache.py
Original file line number Diff line number Diff line change
@@ -51,7 +51,9 @@ def _parallel_gate_op(gate, qubits):
import sympy

return {
'AnyIntegerPowerGateFamily': cirq.AnyIntegerPowerGateFamily,
'AmplitudeDampingChannel': cirq.AmplitudeDampingChannel,
'AnyUnitaryGateFamily': cirq.AnyUnitaryGateFamily,
'AsymmetricDepolarizingChannel': cirq.AsymmetricDepolarizingChannel,
'BitFlipChannel': cirq.BitFlipChannel,
'BitstringAccumulator': cirq.work.BitstringAccumulator,
@@ -81,6 +83,7 @@ def _parallel_gate_op(gate, qubits):
'MutableDensePauliString': cirq.MutableDensePauliString,
'MutablePauliString': cirq.MutablePauliString,
'ObservableMeasuredResult': cirq.work.ObservableMeasuredResult,
'GateFamily': cirq.GateFamily,
'GateOperation': cirq.GateOperation,
'GeneralizedAmplitudeDampingChannel': cirq.GeneralizedAmplitudeDampingChannel,
'GlobalPhaseOperation': cirq.GlobalPhaseOperation,
@@ -115,6 +118,7 @@ def _parallel_gate_op(gate, qubits):
'_PauliZ': cirq.ops.pauli_gates._PauliZ,
'ParamResolver': cirq.ParamResolver,
'ParallelGate': cirq.ParallelGate,
'ParallelGateFamily': cirq.ParallelGateFamily,
'PauliMeasurementGate': cirq.PauliMeasurementGate,
'PauliString': cirq.PauliString,
'PhaseDampingChannel': cirq.PhaseDampingChannel,
32 changes: 32 additions & 0 deletions cirq-core/cirq/ops/common_gate_families.py
Original file line number Diff line number Diff line change
@@ -53,6 +53,13 @@ def __repr__(self) -> str:
def _value_equality_values_(self) -> Any:
return self._num_qubits

def _json_dict_(self):
return {'num_qubits': self._num_qubits}

@classmethod
def _from_json_dict_(cls, num_qubits, **kwargs):
return cls(num_qubits)


class AnyIntegerPowerGateFamily(gateset.GateFamily):
"""GateFamily which accepts instances of a given `cirq.EigenGate`, raised to integer power."""
@@ -87,6 +94,15 @@ def __repr__(self) -> str:
def _value_equality_values_(self) -> Any:
return self.gate

def _json_dict_(self):
return {'gate': self._gate_json()}

@classmethod
def _from_json_dict_(cls, gate, **kwargs):
if isinstance(gate, str):
gate = protocols.cirq_type_from_json(gate)
return cls(gate)


class ParallelGateFamily(gateset.GateFamily):
"""GateFamily which accepts instances of `cirq.ParallelGate` and it's sub_gate.
@@ -175,3 +191,19 @@ def __repr__(self) -> str:
def _value_equality_values_(self) -> Any:
# `isinstance` is used to ensure the a gate type and gate instance is not compared.
return super()._value_equality_values_() + (self._max_parallel_allowed,)

def _json_dict_(self):
return {
'gate': self._gate_json(),
'name': self.name,
'description': self.description,
'max_parallel_allowed': self._max_parallel_allowed,
}

@classmethod
def _from_json_dict_(cls, gate, name, description, max_parallel_allowed, **kwargs):
if isinstance(gate, str):
gate = protocols.cirq_type_from_json(gate)
return cls(
gate, name=name, description=description, max_parallel_allowed=max_parallel_allowed
)
19 changes: 19 additions & 0 deletions cirq-core/cirq/ops/gateset.py
Original file line number Diff line number Diff line change
@@ -97,6 +97,9 @@ def __init__(
def _gate_str(self, gettr: Callable[[Any], str] = str) -> str:
return _gate_str(self.gate, gettr)

def _gate_json(self) -> Union[raw_types.Gate, str]:
return self.gate if not isinstance(self.gate, type) else protocols.json_cirq_type(self.gate)

def _default_name(self) -> str:
family_type = 'Instance' if isinstance(self.gate, raw_types.Gate) else 'Type'
return f'{family_type} GateFamily: {self._gate_str()}'
@@ -167,6 +170,22 @@ def _value_equality_values_(self) -> Any:
self._ignore_global_phase,
)

def _json_dict_(self):
return {
'gate': self._gate_json(),
'name': self.name,
'description': self.description,
'ignore_global_phase': self._ignore_global_phase,
}

@classmethod
def _from_json_dict_(cls, gate, name, description, ignore_global_phase, **kwargs):
if isinstance(gate, str):
gate = protocols.cirq_type_from_json(gate)
return cls(
gate, name=name, description=description, ignore_global_phase=ignore_global_phase
)


@value.value_equality()
class Gateset:
8 changes: 8 additions & 0 deletions cirq-core/cirq/ops/gateset_test.py
Original file line number Diff line number Diff line change
@@ -97,6 +97,14 @@ def test_gate_family_repr_and_str(gate, name, description):
assert g.description in str(g)


@pytest.mark.parametrize('gate', [cirq.X, cirq.XPowGate(), cirq.XPowGate])
@pytest.mark.parametrize('name, description', [(None, None), ('custom_name', 'custom_description')])
def test_gate_family_json(gate, name, description):
g = cirq.GateFamily(gate, name=name, description=description)
g_json = cirq.to_json(g)
assert cirq.read_json(json_text=g_json) == g


def test_gate_family_eq():
eq = cirq.testing.EqualsTester()
eq.add_equality_group(cirq.GateFamily(CustomX))
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"cirq_type": "AnyIntegerPowerGateFamily",
"gate": "XPowGate"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
cirq.AnyIntegerPowerGateFamily(cirq.ops.common_gates.XPowGate)
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"cirq_type": "AnyUnitaryGateFamily",
"num_qubits": 2
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
cirq.AnyUnitaryGateFamily(num_qubits = 2)
20 changes: 20 additions & 0 deletions cirq-core/cirq/protocols/json_test_data/GateFamily.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
[
{
"cirq_type": "GateFamily",
"gate": "XPowGate",
"name": "Type GateFamily: cirq.ops.common_gates.XPowGate",
"description": "Accepts `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.XPowGate)`",
"ignore_global_phase": true
},
{
"cirq_type": "GateFamily",
"gate": {
"cirq_type": "_PauliX",
"exponent": 1.0,
"global_shift": 0.0
},
"name": "XFamily",
"description": "Just the X gate.",
"ignore_global_phase": false
}
]
4 changes: 4 additions & 0 deletions cirq-core/cirq/protocols/json_test_data/GateFamily.repr
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[
cirq.GateFamily(gate=cirq.ops.common_gates.XPowGate, ignore_global_phase=True),
cirq.GateFamily(gate=cirq.X, name="XFamily", description="Just the X gate.", ignore_global_phase=False)
]
20 changes: 20 additions & 0 deletions cirq-core/cirq/protocols/json_test_data/ParallelGateFamily.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
[
{
"cirq_type": "ParallelGateFamily",
"gate": "XPowGate",
"name": "INF Parallel Type GateFamily: cirq.ops.common_gates.XPowGate",
"description": "Accepts\n1. `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.XPowGate)` OR\n2. `cirq.ParallelGate` instance `g` s.t. `g.sub_gate` satisfies 1. and `cirq.num_qubits(g) <= INF` OR\n3. `cirq.Operation` instance `op` s.t. `op.gate` satisfies 1. or 2.",
"max_parallel_allowed": null
},
{
"cirq_type": "ParallelGateFamily",
"gate": {
"cirq_type": "_PauliX",
"exponent": 1.0,
"global_shift": 0.0
},
"name": "ParallelXFamily",
"description": "Up to 4 parallel X gates",
"max_parallel_allowed": 4
}
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[
cirq.ParallelGateFamily(gate=cirq.ops.common_gates.XPowGate, max_parallel_allowed=None),
cirq.ParallelGateFamily(
gate=cirq.X,
name="ParallelXFamily",
description=r'''Up to 4 parallel X gates''',
max_parallel_allowed=4
)
]
4 changes: 0 additions & 4 deletions cirq-core/cirq/protocols/json_test_data/spec.py
Original file line number Diff line number Diff line change
@@ -25,8 +25,6 @@
resolver_cache=_class_resolver_dictionary(),
not_yet_serializable=[
'Alignment',
'AnyIntegerPowerGateFamily',
'AnyUnitaryGateFamily',
'AxisAngleDecomposition',
'CircuitDag',
'CircuitDiagramInfo',
@@ -39,7 +37,6 @@
'DensityMatrixStepResult',
'DensityMatrixTrialResult',
'ExpressionMap',
'GateFamily',
'Gateset',
'InsertStrategy',
'IonDevice',
@@ -50,7 +47,6 @@
'ListSweep',
'DiagonalGate',
'NeutralAtomDevice',
'ParallelGateFamily',
'PauliInteractionGate',
'PauliStringPhasor',
'PauliSum',
1 change: 1 addition & 0 deletions cirq-google/cirq_google/json_resolver_cache.py
Original file line number Diff line number Diff line change
@@ -33,6 +33,7 @@ def _class_resolver_dictionary() -> Dict[str, ObjectFactory]:
'SycamoreGate': cirq_google.SycamoreGate,
'GateTabulation': cirq_google.GateTabulation,
'PhysicalZTag': cirq_google.PhysicalZTag,
'FSimGateFamily': cirq_google.FSimGateFamily,
'FloquetPhasedFSimCalibrationOptions': cirq_google.FloquetPhasedFSimCalibrationOptions,
'FloquetPhasedFSimCalibrationRequest': cirq_google.FloquetPhasedFSimCalibrationRequest,
'PhasedFSimCalibrationResult': cirq_google.PhasedFSimCalibrationResult,
52 changes: 52 additions & 0 deletions cirq-google/cirq_google/json_test_data/FSimGateFamily.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
[
{
"cirq_type": "FSimGateFamily",
"gates_to_accept": [
"ISwapPowGate"
],
"gate_types_to_check": [
"FSimGate",
"PhasedFSimGate",
"ISwapPowGate",
"PhasedISwapPowGate",
"CZPowGate",
"IdentityGate"
],
"allow_symbols": false,
"atol": 1e-06
},
{
"cirq_type": "FSimGateFamily",
"gates_to_accept": [
{
"cirq_type": "ISwapPowGate",
"exponent": 1.0,
"global_shift": 0.0
},
{
"cirq_type": "CZPowGate",
"exponent": 1.0,
"global_shift": 0.0
}
],
"gate_types_to_check": [
"FSimGate",
"PhasedFSimGate",
"ISwapPowGate",
"PhasedISwapPowGate",
"CZPowGate",
"IdentityGate"
],
"allow_symbols": false,
"atol": 1e-06
},
{
"cirq_type": "FSimGateFamily",
"gates_to_accept": [],
"gate_types_to_check": [
"IdentityGate"
],
"allow_symbols": true,
"atol": 0.0001
}
]
34 changes: 34 additions & 0 deletions cirq-google/cirq_google/json_test_data/FSimGateFamily.repr
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
[
cirq_google.FSimGateFamily(
gates_to_accept=[cirq.ops.swap_gates.ISwapPowGate],
gate_types_to_check=[
cirq.ops.fsim_gate.FSimGate,
cirq.ops.fsim_gate.PhasedFSimGate,
cirq.ops.swap_gates.ISwapPowGate,
cirq.ops.phased_iswap_gate.PhasedISwapPowGate,
cirq.ops.common_gates.CZPowGate,
cirq.ops.identity.IdentityGate
],
allow_symbols=False,
atol=1e-06
),
cirq_google.FSimGateFamily(
gates_to_accept=[cirq.ISWAP,cirq.CZ],
gate_types_to_check=[
cirq.ops.fsim_gate.FSimGate,
cirq.ops.fsim_gate.PhasedFSimGate,
cirq.ops.swap_gates.ISwapPowGate,
cirq.ops.phased_iswap_gate.PhasedISwapPowGate,
cirq.ops.common_gates.CZPowGate,
cirq.ops.identity.IdentityGate
],
allow_symbols=False,
atol=1e-06
),
cirq_google.FSimGateFamily(
gates_to_accept=[],
gate_types_to_check=[cirq.ops.identity.IdentityGate],
allow_symbols=True,
atol=0.0001
)
]
1 change: 0 additions & 1 deletion cirq-google/cirq_google/json_test_data/spec.py
Original file line number Diff line number Diff line change
@@ -37,7 +37,6 @@
'EngineJob',
'EngineProcessor',
'EngineProgram',
'FSimGateFamily',
'FSimPhaseCorrections',
'NAMED_GATESETS',
'ProtoVersion',
27 changes: 27 additions & 0 deletions cirq-google/cirq_google/ops/fsim_gate_family.py
Original file line number Diff line number Diff line change
@@ -203,6 +203,33 @@ def _value_equality_values_(self) -> Any:
self.atol,
)

def _json_dict_(self):
accept_gates_json = [
gate if not isinstance(gate, type) else cirq.json_cirq_type(gate)
for gate in self.gates_to_accept
]
check_gates_json = [cirq.json_cirq_type(gate) for gate in self.gate_types_to_check]
return {
'gates_to_accept': accept_gates_json,
'gate_types_to_check': check_gates_json,
'allow_symbols': self.allow_symbols,
'atol': self.atol,
}

@classmethod
def _from_json_dict_(cls, gates_to_accept, gate_types_to_check, allow_symbols, atol, **kwargs):
accept_gates = [
gate if not isinstance(gate, str) else cirq.cirq_type_from_json(gate)
for gate in gates_to_accept
]
check_gates = [cirq.cirq_type_from_json(gate) for gate in gate_types_to_check]
return cls(
gates_to_accept=accept_gates,
gate_types_to_check=check_gates,
allow_symbols=allow_symbols,
atol=atol,
)

def _approx_eq_or_symbol(self, lhs: Any, rhs: Any) -> bool:
lhs = lhs if isinstance(lhs, tuple) else (lhs,)
rhs = rhs if isinstance(rhs, tuple) else (rhs,)

0 comments on commit 9053a27

Please sign in to comment.