Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support serialization of GateFamilies #4715

Merged
merged 5 commits into from
Dec 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cirq-core/cirq/json_resolver_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ def _parallel_gate_op(gate, qubits):
import sympy

return {
'AnyIntegerPowerGateFamily': cirq.AnyIntegerPowerGateFamily,
'AmplitudeDampingChannel': cirq.AmplitudeDampingChannel,
'AnyUnitaryGateFamily': cirq.AnyUnitaryGateFamily,
'AsymmetricDepolarizingChannel': cirq.AsymmetricDepolarizingChannel,
'BitFlipChannel': cirq.BitFlipChannel,
'BitstringAccumulator': cirq.work.BitstringAccumulator,
Expand Down Expand Up @@ -81,6 +83,7 @@ def _parallel_gate_op(gate, qubits):
'MutableDensePauliString': cirq.MutableDensePauliString,
'MutablePauliString': cirq.MutablePauliString,
'ObservableMeasuredResult': cirq.work.ObservableMeasuredResult,
'GateFamily': cirq.GateFamily,
'GateOperation': cirq.GateOperation,
'GeneralizedAmplitudeDampingChannel': cirq.GeneralizedAmplitudeDampingChannel,
'GlobalPhaseOperation': cirq.GlobalPhaseOperation,
Expand Down Expand Up @@ -115,6 +118,7 @@ def _parallel_gate_op(gate, qubits):
'_PauliZ': cirq.ops.pauli_gates._PauliZ,
'ParamResolver': cirq.ParamResolver,
'ParallelGate': cirq.ParallelGate,
'ParallelGateFamily': cirq.ParallelGateFamily,
'PauliMeasurementGate': cirq.PauliMeasurementGate,
'PauliString': cirq.PauliString,
'PhaseDampingChannel': cirq.PhaseDampingChannel,
Expand Down
32 changes: 32 additions & 0 deletions cirq-core/cirq/ops/common_gate_families.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@ def __repr__(self) -> str:
def _value_equality_values_(self) -> Any:
return self._num_qubits

def _json_dict_(self):
return {'num_qubits': self._num_qubits}

@classmethod
def _from_json_dict_(cls, num_qubits, **kwargs):
return cls(num_qubits)


class AnyIntegerPowerGateFamily(gateset.GateFamily):
"""GateFamily which accepts instances of a given `cirq.EigenGate`, raised to integer power."""
Expand Down Expand Up @@ -87,6 +94,15 @@ def __repr__(self) -> str:
def _value_equality_values_(self) -> Any:
return self.gate

def _json_dict_(self):
return {'gate': self._gate_json()}

@classmethod
def _from_json_dict_(cls, gate, **kwargs):
if isinstance(gate, str):
gate = protocols.cirq_type_from_json(gate)
return cls(gate)


class ParallelGateFamily(gateset.GateFamily):
"""GateFamily which accepts instances of `cirq.ParallelGate` and it's sub_gate.
Expand Down Expand Up @@ -175,3 +191,19 @@ def __repr__(self) -> str:
def _value_equality_values_(self) -> Any:
# `isinstance` is used to ensure the a gate type and gate instance is not compared.
return super()._value_equality_values_() + (self._max_parallel_allowed,)

def _json_dict_(self):
return {
'gate': self._gate_json(),
'name': self.name,
'description': self.description,
'max_parallel_allowed': self._max_parallel_allowed,
}

@classmethod
def _from_json_dict_(cls, gate, name, description, max_parallel_allowed, **kwargs):
if isinstance(gate, str):
gate = protocols.cirq_type_from_json(gate)
return cls(
gate, name=name, description=description, max_parallel_allowed=max_parallel_allowed
)
19 changes: 19 additions & 0 deletions cirq-core/cirq/ops/gateset.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ def __init__(
def _gate_str(self, gettr: Callable[[Any], str] = str) -> str:
return _gate_str(self.gate, gettr)

def _gate_json(self) -> Union[raw_types.Gate, str]:
return self.gate if not isinstance(self.gate, type) else protocols.json_cirq_type(self.gate)

def _default_name(self) -> str:
family_type = 'Instance' if isinstance(self.gate, raw_types.Gate) else 'Type'
return f'{family_type} GateFamily: {self._gate_str()}'
Expand Down Expand Up @@ -167,6 +170,22 @@ def _value_equality_values_(self) -> Any:
self._ignore_global_phase,
)

def _json_dict_(self):
return {
'gate': self._gate_json(),
'name': self.name,
'description': self.description,
'ignore_global_phase': self._ignore_global_phase,
}

@classmethod
def _from_json_dict_(cls, gate, name, description, ignore_global_phase, **kwargs):
if isinstance(gate, str):
gate = protocols.cirq_type_from_json(gate)
return cls(
gate, name=name, description=description, ignore_global_phase=ignore_global_phase
)


@value.value_equality()
class Gateset:
Expand Down
8 changes: 8 additions & 0 deletions cirq-core/cirq/ops/gateset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,14 @@ def test_gate_family_repr_and_str(gate, name, description):
assert g.description in str(g)


@pytest.mark.parametrize('gate', [cirq.X, cirq.XPowGate(), cirq.XPowGate])
@pytest.mark.parametrize('name, description', [(None, None), ('custom_name', 'custom_description')])
def test_gate_family_json(gate, name, description):
g = cirq.GateFamily(gate, name=name, description=description)
g_json = cirq.to_json(g)
assert cirq.read_json(json_text=g_json) == g


def test_gate_family_eq():
eq = cirq.testing.EqualsTester()
eq.add_equality_group(cirq.GateFamily(CustomX))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"cirq_type": "AnyIntegerPowerGateFamily",
"gate": "XPowGate"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
cirq.AnyIntegerPowerGateFamily(cirq.ops.common_gates.XPowGate)
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"cirq_type": "AnyUnitaryGateFamily",
"num_qubits": 2
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
cirq.AnyUnitaryGateFamily(num_qubits = 2)
20 changes: 20 additions & 0 deletions cirq-core/cirq/protocols/json_test_data/GateFamily.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
[
{
"cirq_type": "GateFamily",
"gate": "XPowGate",
"name": "Type GateFamily: cirq.ops.common_gates.XPowGate",
"description": "Accepts `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.XPowGate)`",
"ignore_global_phase": true
},
{
"cirq_type": "GateFamily",
"gate": {
"cirq_type": "_PauliX",
"exponent": 1.0,
"global_shift": 0.0
},
"name": "XFamily",
"description": "Just the X gate.",
"ignore_global_phase": false
}
]
4 changes: 4 additions & 0 deletions cirq-core/cirq/protocols/json_test_data/GateFamily.repr
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[
cirq.GateFamily(gate=cirq.ops.common_gates.XPowGate, ignore_global_phase=True),
cirq.GateFamily(gate=cirq.X, name="XFamily", description="Just the X gate.", ignore_global_phase=False)
]
20 changes: 20 additions & 0 deletions cirq-core/cirq/protocols/json_test_data/ParallelGateFamily.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
[
{
"cirq_type": "ParallelGateFamily",
"gate": "XPowGate",
"name": "INF Parallel Type GateFamily: cirq.ops.common_gates.XPowGate",
"description": "Accepts\n1. `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.XPowGate)` OR\n2. `cirq.ParallelGate` instance `g` s.t. `g.sub_gate` satisfies 1. and `cirq.num_qubits(g) <= INF` OR\n3. `cirq.Operation` instance `op` s.t. `op.gate` satisfies 1. or 2.",
"max_parallel_allowed": null
},
{
"cirq_type": "ParallelGateFamily",
"gate": {
"cirq_type": "_PauliX",
"exponent": 1.0,
"global_shift": 0.0
},
"name": "ParallelXFamily",
"description": "Up to 4 parallel X gates",
"max_parallel_allowed": 4
}
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[
cirq.ParallelGateFamily(gate=cirq.ops.common_gates.XPowGate, max_parallel_allowed=None),
cirq.ParallelGateFamily(
gate=cirq.X,
name="ParallelXFamily",
description=r'''Up to 4 parallel X gates''',
max_parallel_allowed=4
)
]
4 changes: 0 additions & 4 deletions cirq-core/cirq/protocols/json_test_data/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
resolver_cache=_class_resolver_dictionary(),
not_yet_serializable=[
'Alignment',
'AnyIntegerPowerGateFamily',
'AnyUnitaryGateFamily',
'AxisAngleDecomposition',
'CircuitDag',
'CircuitDiagramInfo',
Expand All @@ -39,7 +37,6 @@
'DensityMatrixStepResult',
'DensityMatrixTrialResult',
'ExpressionMap',
'GateFamily',
'Gateset',
'InsertStrategy',
'IonDevice',
Expand All @@ -50,7 +47,6 @@
'ListSweep',
'DiagonalGate',
'NeutralAtomDevice',
'ParallelGateFamily',
'PauliInteractionGate',
'PauliStringPhasor',
'PauliSum',
Expand Down
1 change: 1 addition & 0 deletions cirq-google/cirq_google/json_resolver_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def _class_resolver_dictionary() -> Dict[str, ObjectFactory]:
'SycamoreGate': cirq_google.SycamoreGate,
'GateTabulation': cirq_google.GateTabulation,
'PhysicalZTag': cirq_google.PhysicalZTag,
'FSimGateFamily': cirq_google.FSimGateFamily,
'FloquetPhasedFSimCalibrationOptions': cirq_google.FloquetPhasedFSimCalibrationOptions,
'FloquetPhasedFSimCalibrationRequest': cirq_google.FloquetPhasedFSimCalibrationRequest,
'PhasedFSimCalibrationResult': cirq_google.PhasedFSimCalibrationResult,
Expand Down
52 changes: 52 additions & 0 deletions cirq-google/cirq_google/json_test_data/FSimGateFamily.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
[
{
"cirq_type": "FSimGateFamily",
"gates_to_accept": [
"ISwapPowGate"
],
"gate_types_to_check": [
"FSimGate",
"PhasedFSimGate",
"ISwapPowGate",
"PhasedISwapPowGate",
"CZPowGate",
"IdentityGate"
],
"allow_symbols": false,
"atol": 1e-06
},
{
"cirq_type": "FSimGateFamily",
"gates_to_accept": [
{
"cirq_type": "ISwapPowGate",
"exponent": 1.0,
"global_shift": 0.0
},
{
"cirq_type": "CZPowGate",
"exponent": 1.0,
"global_shift": 0.0
}
],
"gate_types_to_check": [
"FSimGate",
"PhasedFSimGate",
"ISwapPowGate",
"PhasedISwapPowGate",
"CZPowGate",
"IdentityGate"
],
"allow_symbols": false,
"atol": 1e-06
},
{
"cirq_type": "FSimGateFamily",
"gates_to_accept": [],
"gate_types_to_check": [
"IdentityGate"
],
"allow_symbols": true,
"atol": 0.0001
}
]
34 changes: 34 additions & 0 deletions cirq-google/cirq_google/json_test_data/FSimGateFamily.repr
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
[
cirq_google.FSimGateFamily(
gates_to_accept=[cirq.ops.swap_gates.ISwapPowGate],
gate_types_to_check=[
cirq.ops.fsim_gate.FSimGate,
cirq.ops.fsim_gate.PhasedFSimGate,
cirq.ops.swap_gates.ISwapPowGate,
cirq.ops.phased_iswap_gate.PhasedISwapPowGate,
cirq.ops.common_gates.CZPowGate,
cirq.ops.identity.IdentityGate
],
allow_symbols=False,
atol=1e-06
),
cirq_google.FSimGateFamily(
gates_to_accept=[cirq.ISWAP,cirq.CZ],
gate_types_to_check=[
cirq.ops.fsim_gate.FSimGate,
cirq.ops.fsim_gate.PhasedFSimGate,
cirq.ops.swap_gates.ISwapPowGate,
cirq.ops.phased_iswap_gate.PhasedISwapPowGate,
cirq.ops.common_gates.CZPowGate,
cirq.ops.identity.IdentityGate
],
allow_symbols=False,
atol=1e-06
),
cirq_google.FSimGateFamily(
gates_to_accept=[],
gate_types_to_check=[cirq.ops.identity.IdentityGate],
allow_symbols=True,
atol=0.0001
)
]
1 change: 0 additions & 1 deletion cirq-google/cirq_google/json_test_data/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
'EngineJob',
'EngineProcessor',
'EngineProgram',
'FSimGateFamily',
'FSimPhaseCorrections',
'NAMED_GATESETS',
'ProtoVersion',
Expand Down
27 changes: 27 additions & 0 deletions cirq-google/cirq_google/ops/fsim_gate_family.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,33 @@ def _value_equality_values_(self) -> Any:
self.atol,
)

def _json_dict_(self):
accept_gates_json = [
gate if not isinstance(gate, type) else cirq.json_cirq_type(gate)
for gate in self.gates_to_accept
]
check_gates_json = [cirq.json_cirq_type(gate) for gate in self.gate_types_to_check]
return {
'gates_to_accept': accept_gates_json,
'gate_types_to_check': check_gates_json,
'allow_symbols': self.allow_symbols,
'atol': self.atol,
}

@classmethod
def _from_json_dict_(cls, gates_to_accept, gate_types_to_check, allow_symbols, atol, **kwargs):
accept_gates = [
gate if not isinstance(gate, str) else cirq.cirq_type_from_json(gate)
for gate in gates_to_accept
]
check_gates = [cirq.cirq_type_from_json(gate) for gate in gate_types_to_check]
return cls(
gates_to_accept=accept_gates,
gate_types_to_check=check_gates,
allow_symbols=allow_symbols,
atol=atol,
)

def _approx_eq_or_symbol(self, lhs: Any, rhs: Any) -> bool:
lhs = lhs if isinstance(lhs, tuple) else (lhs,)
rhs = rhs if isinstance(rhs, tuple) else (rhs,)
Expand Down