From 88065168aaf2e2170fdf115dde0378a29949703d Mon Sep 17 00:00:00 2001 From: MichaelBroughton Date: Thu, 13 Jan 2022 15:51:38 -0800 Subject: [PATCH] Add DeviceMetaData class. (#4832) Adds standalone DeviceMetaData class. First step in #4743 . --- cirq/__init__.py | 1 + cirq/devices/__init__.py | 1 + cirq/devices/device.py | 77 ++++++++++++++++++- cirq/devices/device_test.py | 28 +++++++ cirq/json_resolver_cache.py | 1 + cirq/protocols/json_serialization_test.py | 9 +-- .../json_test_data/DeviceMetadata.json | 54 +++++++++++++ .../json_test_data/DeviceMetadata.repr | 1 + 8 files changed, 164 insertions(+), 8 deletions(-) create mode 100644 cirq/protocols/json_test_data/DeviceMetadata.json create mode 100644 cirq/protocols/json_test_data/DeviceMetadata.repr diff --git a/cirq/__init__.py b/cirq/__init__.py index 137ca460eda..5afab01a07d 100644 --- a/cirq/__init__.py +++ b/cirq/__init__.py @@ -81,6 +81,7 @@ from cirq.devices import ( ConstantQubitNoiseModel, Device, + DeviceMetadata, GridQid, GridQubit, LineQid, diff --git a/cirq/devices/__init__.py b/cirq/devices/__init__.py index b1480a9513d..18a7100a035 100644 --- a/cirq/devices/__init__.py +++ b/cirq/devices/__init__.py @@ -15,6 +15,7 @@ """Types for devices, device-specific qubits, and noise models.""" from cirq.devices.device import ( Device, + DeviceMetadata, SymmetricalQidPair, ) diff --git a/cirq/devices/device.py b/cirq/devices/device.py index 05da94b0ca0..701136d13d9 100644 --- a/cirq/devices/device.py +++ b/cirq/devices/device.py @@ -13,8 +13,9 @@ # limitations under the License. import abc -from typing import TYPE_CHECKING, Optional, AbstractSet, cast, FrozenSet, Iterator +from typing import TYPE_CHECKING, Optional, AbstractSet, cast, FrozenSet, Iterator, Iterable +import networkx as nx from cirq import value from cirq.devices.grid_qubit import _BaseGridQid from cirq.devices.line_qubit import _BaseLineQid @@ -178,3 +179,77 @@ def __iter__(self) -> Iterator['cirq.Qid']: def __contains__(self, item: 'cirq.Qid') -> bool: return item in self.qids + + +@value.value_equality +class DeviceMetadata: + """Parent type for all device specific metadata classes.""" + + def __init__( + self, + qubits: Optional[Iterable['cirq.Qid']] = None, + nx_graph: Optional['nx.graph'] = None, + ): + """Construct a DeviceMetadata object. + + Args: + qubits: Optional iterable of `cirq.Qid`s that exist on the device. + nx_graph: Optional `nx.Graph` describing qubit connectivity + on a device. Nodes represent qubits, directed edges indicate + directional coupling, undirected edges indicate bi-directional + coupling. + """ + if qubits is not None: + qubits = frozenset(qubits) + self._qubits_set: Optional[FrozenSet['cirq.Qid']] = ( + None if qubits is None else frozenset(qubits) + ) + + self._nx_graph = nx_graph + + def qubit_set(self) -> Optional[FrozenSet['cirq.Qid']]: + """Returns a set of qubits on the device, if possible. + + Returns: + Frozenset of qubits on device if specified, otherwise None. + """ + return self._qubits_set + + def nx_graph(self) -> Optional['nx.Graph']: + """Returns a nx.Graph where nodes are qubits and edges are couple-able qubits. + + Returns: + `nx.Graph` of device connectivity if specified, otherwise None. + """ + return self._nx_graph + + def _value_equality_values_(self): + graph_equality = None + if self._nx_graph is not None: + graph_equality = (sorted(self._nx_graph.nodes()), sorted(self._nx_graph.edges())) + + qubit_equality = None + if self._qubits_set is not None: + qubit_equality = sorted(list(self._qubits_set)) + + return qubit_equality, graph_equality + + def _json_dict_(self): + graph_payload = '' + if self._nx_graph is not None: + graph_payload = nx.readwrite.json_graph.node_link_data(self._nx_graph) + + qubits_payload = '' + if self._qubits_set is not None: + qubits_payload = sorted(list(self._qubits_set)) + + return {'qubits': qubits_payload, 'nx_graph': graph_payload} + + @classmethod + def _from_json_dict_(cls, qubits, nx_graph, **kwargs): + if qubits == '': + qubits = None + graph_obj = None + if nx_graph != '': + graph_obj = nx.readwrite.json_graph.node_link_graph(nx_graph) + return cls(qubits, graph_obj) diff --git a/cirq/devices/device_test.py b/cirq/devices/device_test.py index f730492b5e3..502c357925e 100644 --- a/cirq/devices/device_test.py +++ b/cirq/devices/device_test.py @@ -1,5 +1,6 @@ # pylint: disable=wrong-or-nonexistent-copyright-notice import pytest +import networkx as nx import cirq @@ -75,3 +76,30 @@ def test_qid_pair(): with pytest.raises(ValueError, match='A QidPair cannot have identical qids.'): cirq.SymmetricalQidPair(q0, q0) + + +def test_metadata(): + qubits = cirq.LineQubit.range(4) + graph = nx.star_graph(3) + metadata = cirq.DeviceMetadata(qubits, graph) + assert metadata.qubit_set() == frozenset(qubits) + assert metadata.nx_graph() == graph + + metadata = cirq.DeviceMetadata() + assert metadata.qubit_set() is None + assert metadata.nx_graph() is None + + +def test_metadata_json_load_logic(): + qubits = cirq.LineQubit.range(4) + graph = nx.star_graph(3) + metadata = cirq.DeviceMetadata(qubits, graph) + str_rep = cirq.to_json(metadata) + assert metadata == cirq.read_json(json_text=str_rep) + + qubits = None + graph = None + metadata = cirq.DeviceMetadata(qubits, graph) + str_rep = cirq.to_json(metadata) + output = cirq.read_json(json_text=str_rep) + assert metadata == output diff --git a/cirq/json_resolver_cache.py b/cirq/json_resolver_cache.py index c60a8cf66d1..2e7cc71797f 100644 --- a/cirq/json_resolver_cache.py +++ b/cirq/json_resolver_cache.py @@ -77,6 +77,7 @@ def _parallel_gate_op(gate, qubits): 'CZPowGate': cirq.CZPowGate, 'DensePauliString': cirq.DensePauliString, 'DepolarizingChannel': cirq.DepolarizingChannel, + 'DeviceMetadata': cirq.DeviceMetadata, 'Duration': cirq.Duration, 'FrozenCircuit': cirq.FrozenCircuit, 'FSimGate': cirq.FSimGate, diff --git a/cirq/protocols/json_serialization_test.py b/cirq/protocols/json_serialization_test.py index 51244fd22c1..c69a0a017d0 100644 --- a/cirq/protocols/json_serialization_test.py +++ b/cirq/protocols/json_serialization_test.py @@ -24,6 +24,7 @@ from typing import ClassVar, Dict, List, Optional, Tuple, Type from unittest import mock +import networkx as nx import numpy as np import pandas as pd import pytest @@ -726,13 +727,7 @@ def _eval_repr_data_file(path: pathlib.Path, deprecation_deadline: Optional[str] if deprecation is not None and deprecation.old_name in content: ctx_managers.append(deprecation.deprecation_assertion) - imports = { - 'cirq': cirq, - 'pd': pd, - 'sympy': sympy, - 'np': np, - 'datetime': datetime, - } + imports = {'cirq': cirq, 'pd': pd, 'sympy': sympy, 'np': np, 'datetime': datetime, 'nx': nx} for m in TESTED_MODULES.keys(): try: diff --git a/cirq/protocols/json_test_data/DeviceMetadata.json b/cirq/protocols/json_test_data/DeviceMetadata.json new file mode 100644 index 00000000000..f02295ae929 --- /dev/null +++ b/cirq/protocols/json_test_data/DeviceMetadata.json @@ -0,0 +1,54 @@ +{ + "cirq_type": "DeviceMetadata", + "qubits": [ + { + "cirq_type": "LineQubit", + "x": 0 + }, + { + "cirq_type": "LineQubit", + "x": 1 + }, + { + "cirq_type": "LineQubit", + "x": 2 + }, + { + "cirq_type": "LineQubit", + "x": 3 + } + ], + "nx_graph": { + "directed": false, + "multigraph": false, + "graph": {}, + "nodes": [ + { + "id": 0 + }, + { + "id": 1 + }, + { + "id": 2 + }, + { + "id": 3 + } + ], + "links": [ + { + "source": 0, + "target": 1 + }, + { + "source": 0, + "target": 2 + }, + { + "source": 0, + "target": 3 + } + ] + } +} diff --git a/cirq/protocols/json_test_data/DeviceMetadata.repr b/cirq/protocols/json_test_data/DeviceMetadata.repr new file mode 100644 index 00000000000..866722f0248 --- /dev/null +++ b/cirq/protocols/json_test_data/DeviceMetadata.repr @@ -0,0 +1 @@ +cirq.DeviceMetadata(cirq.LineQubit.range(4), nx.star_graph(3)) \ No newline at end of file