Skip to content

Commit

Permalink
Add json serialization to sweeps (quantumlib#5618)
Browse files Browse the repository at this point in the history
Adds missing json serialization for sweep objects.
  • Loading branch information
dabacon authored Jul 12, 2022
1 parent d631f39 commit 11e77cf
Show file tree
Hide file tree
Showing 15 changed files with 119 additions and 6 deletions.
6 changes: 6 additions & 0 deletions cirq/json_resolver_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ def _symmetricalqidpair(qids):
'LineQubit': cirq.LineQubit,
'LineQid': cirq.LineQid,
'LineTopology': cirq.LineTopology,
'Linspace': cirq.Linspace,
'ListSweep': cirq.ListSweep,
'MatrixGate': cirq.MatrixGate,
'MixedUnitaryChannel': cirq.MixedUnitaryChannel,
'MeasurementKey': cirq.MeasurementKey,
Expand Down Expand Up @@ -187,6 +189,8 @@ def _symmetricalqidpair(qids):
'PhasedISwapPowGate': cirq.PhasedISwapPowGate,
'PhasedXPowGate': cirq.PhasedXPowGate,
'PhasedXZGate': cirq.PhasedXZGate,
'Points': cirq.Points,
'Product': cirq.Product,
'ProductState': cirq.ProductState,
'ProductOfSums': cirq.ProductOfSums,
'ProjectorString': cirq.ProjectorString,
Expand Down Expand Up @@ -219,6 +223,7 @@ def _symmetricalqidpair(qids):
'TwoQubitDiagonalGate': cirq.TwoQubitDiagonalGate,
'TwoQubitGateTabulation': cirq.TwoQubitGateTabulation,
'_UnconstrainedDevice': cirq.devices.unconstrained_device._UnconstrainedDevice,
'_Unit': cirq.study.sweeps._Unit,
'VarianceStoppingCriteria': cirq.work.VarianceStoppingCriteria,
'VirtualTag': cirq.VirtualTag,
'WaitGate': cirq.WaitGate,
Expand All @@ -233,6 +238,7 @@ def _symmetricalqidpair(qids):
'YPowGate': cirq.YPowGate,
'YYPowGate': cirq.YYPowGate,
'_ZEigenState': cirq.value.product_state._ZEigenState, # type: ignore
'Zip': cirq.Zip,
'ZPowGate': cirq.ZPowGate,
'ZZPowGate': cirq.ZZPowGate,
# Old types, only supported for backwards-compatibility
Expand Down
7 changes: 7 additions & 0 deletions cirq/protocols/json_test_data/Linspace.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"cirq_type": "Linspace",
"key": "a",
"start": 0,
"stop": 1,
"length": 4
}
1 change: 1 addition & 0 deletions cirq/protocols/json_test_data/Linspace.repr
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
cirq.Linspace("a", start=0, stop=1, length=4)
23 changes: 23 additions & 0 deletions cirq/protocols/json_test_data/ListSweep.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
{
"cirq_type": "ListSweep",
"resolver_list": [
{
"cirq_type": "ParamResolver",
"param_dict": [
[
"a",
0.1
]
]
},
{
"cirq_type": "ParamResolver",
"param_dict": [
[
"b",
0.2
]
]
}
]
}
1 change: 1 addition & 0 deletions cirq/protocols/json_test_data/ListSweep.repr
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
cirq.ListSweep([cirq.ParamResolver({'a': 0.1}), cirq.ParamResolver({'b': 0.2})])
8 changes: 8 additions & 0 deletions cirq/protocols/json_test_data/Points.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"cirq_type": "Points",
"key": "a",
"points": [
0,
0.4
]
}
1 change: 1 addition & 0 deletions cirq/protocols/json_test_data/Points.repr
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
cirq.Points('a', [0, 0.4])
19 changes: 19 additions & 0 deletions cirq/protocols/json_test_data/Product.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
{
"cirq_type": "Product",
"factors": [
{
"cirq_type": "Linspace",
"key": "a",
"start": 0,
"stop": 1,
"length": 2
},
{
"cirq_type": "Linspace",
"key": "b",
"start": 0,
"stop": 2,
"length": 4
}
]
}
1 change: 1 addition & 0 deletions cirq/protocols/json_test_data/Product.repr
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
cirq.Product(cirq.Linspace('a', start=0, stop=1, length=2), cirq.Linspace('b', start=0, stop=2, length=4))
19 changes: 19 additions & 0 deletions cirq/protocols/json_test_data/Zip.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
{
"cirq_type": "Zip",
"sweeps": [
{
"cirq_type": "Linspace",
"key": "a",
"start": 0,
"stop": 1,
"length": 2
},
{
"cirq_type": "Linspace",
"key": "b",
"start": 0,
"stop": 2,
"length": 4
}
]
}
1 change: 1 addition & 0 deletions cirq/protocols/json_test_data/Zip.repr
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
cirq.Zip(cirq.Linspace('a', start=0, stop=1, length=2), cirq.Linspace('b', start=0, stop=2, length=4))
3 changes: 3 additions & 0 deletions cirq/protocols/json_test_data/_Unit.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"cirq_type": "_Unit"
}
1 change: 1 addition & 0 deletions cirq/protocols/json_test_data/_Unit.repr
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
cirq.UnitSweep
7 changes: 1 addition & 6 deletions cirq/protocols/json_test_data/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,10 @@
'KakDecomposition',
'LinearCombinationOfGates',
'LinearCombinationOfOperations',
'Linspace',
'ListSweep',
'PauliSumCollector',
'PauliSumExponential',
'PeriodicValue',
'PointOptimizationSummary',
'Points',
'Product',
'QasmArgs',
'QasmOutput',
'QuantumState',
Expand All @@ -59,10 +55,8 @@
'TextDiagramDrawer',
'Timestamp',
'TwoQubitGateTabulationResult',
'UnitSweep',
'StateVectorTrialResult',
'ZerosSampler',
'Zip',
],
should_not_be_serialized=[
# Heatmaps
Expand Down Expand Up @@ -105,6 +99,7 @@
'SimulatesFinalState',
'StateVectorStepResult',
'StepResultBase',
'UnitSweep',
'NamedTopology',
# protocols:
'HasJSONNamespace',
Expand Down
27 changes: 27 additions & 0 deletions cirq/study/sweeps.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import itertools
import sympy

from cirq import protocols
from cirq._doc import document
from cirq.study import resolver

Expand Down Expand Up @@ -196,6 +197,9 @@ def param_tuples(self) -> Iterator[Params]:
def __repr__(self) -> str:
return 'cirq.UnitSweep'

def _json_dict_(self) -> Dict[str, Any]:
return {}


UnitSweep = _Unit()
document(UnitSweep, """The singleton sweep with no parameters.""")
Expand Down Expand Up @@ -261,6 +265,13 @@ def __str__(self) -> str:
factor_strs.append(factor_str)
return ' * '.join(factor_strs)

def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['factors'])

@classmethod
def _from_json_dict_(cls, factors, **kwargs):
return Product(*factors)


class Zip(Sweep):
"""Zip product (direct sum) of one or more sweeps.
Expand Down Expand Up @@ -311,6 +322,13 @@ def __str__(self) -> str:
return 'Zip()'
return ' + '.join(str(s) if isinstance(s, Product) else repr(s) for s in self.sweeps)

def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['sweeps'])

@classmethod
def _from_json_dict_(cls, sweeps, **kwargs):
return Zip(*sweeps)


class SingleSweep(Sweep):
"""A simple sweep over one parameter with values from an iterator."""
Expand Down Expand Up @@ -364,6 +382,9 @@ def _values(self) -> Iterator[float]:
def __repr__(self) -> str:
return f'cirq.Points({self.key!r}, {self.points!r})'

def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ["key", "points"])


class Linspace(SingleSweep):
"""A simple sweep over linearly-spaced values."""
Expand Down Expand Up @@ -399,6 +420,9 @@ def __repr__(self) -> str:
f'stop={self.stop!r}, length={self.length!r})'
)

def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ["key", "start", "stop", "length"])


class ListSweep(Sweep):
"""A wrapper around a list of `ParamResolver`s."""
Expand Down Expand Up @@ -444,6 +468,9 @@ def param_tuples(self) -> Iterator[Params]:
def __repr__(self) -> str:
return f'cirq.ListSweep({self.resolver_list!r})'

def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ["resolver_list"])


def _params_without_symbols(resolver: resolver.ParamResolver) -> Params:
for sym, val in resolver.param_dict.items():
Expand Down

0 comments on commit 11e77cf

Please sign in to comment.