Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DeviceMetaData class. #4832

Merged
merged 5 commits into from
Jan 13, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
from cirq.devices import (
ConstantQubitNoiseModel,
Device,
DeviceMetadata,
GridQid,
GridQubit,
LineQid,
Expand Down
1 change: 1 addition & 0 deletions cirq-core/cirq/devices/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Types for devices, device-specific qubits, and noise models."""
from cirq.devices.device import (
Device,
DeviceMetadata,
SymmetricalQidPair,
)

Expand Down
77 changes: 76 additions & 1 deletion cirq-core/cirq/devices/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
# limitations under the License.

import abc
from typing import TYPE_CHECKING, Optional, AbstractSet, cast, FrozenSet, Iterator
from typing import TYPE_CHECKING, Optional, AbstractSet, cast, FrozenSet, Iterator, Iterable

import networkx as nx
from cirq import value
from cirq.devices.grid_qubit import _BaseGridQid
from cirq.devices.line_qubit import _BaseLineQid
Expand Down Expand Up @@ -178,3 +179,77 @@ def __iter__(self) -> Iterator['cirq.Qid']:

def __contains__(self, item: 'cirq.Qid') -> bool:
return item in self.qids


@value.value_equality
class DeviceMetadata:
"""Parent type for all device specific metadata classes."""

def __init__(
self,
qubits: Optional[Iterable['cirq.Qid']] = None,
nx_graph: Optional['nx.graph'] = None,
):
"""Construct a DeviceMetadata object.

Args:
qubits: Optional iterable of `cirq.Qid`s that exist on the device.
nx_graph: Optional `nx.Graph` describing qubit connectivity
on a device. Nodes represent qubits, directed edges indicate
directional coupling, undirected edges indicate bi-directional
coupling.
"""
if qubits is not None:
qubits = frozenset(qubits)
self._qubits_set: Optional[FrozenSet['cirq.Qid']] = cast(
Optional[FrozenSet['cirq.Qid']], qubits
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: this would be more concise as

self._qubits_set: Optional[FrozenSet['cirq.Qid']] = (
    None if qubits is None else frozenset(qubits)
)


self._nx_graph = nx_graph

def qubit_set(self) -> Optional[FrozenSet['cirq.Qid']]:
"""Returns a set of qubits on the device, if possible.

Returns:
Frozenset of qubits on device if specified, otherwise None.
"""
return self._qubits_set

def nx_graph(self) -> Optional['nx.Graph']:
"""Returns a nx.Graph where nodes are qubits and edges are couple-able qubits.

Returns:
`nx.Graph` of device connectivity if specified, otherwise None.
"""
return self._nx_graph

def _value_equality_values_(self):
graph_equality = None
if self._nx_graph is not None:
graph_equality = (sorted(self._nx_graph.nodes()), sorted(self._nx_graph.edges()))

qubit_equality = None
if self._qubits_set is not None:
qubit_equality = sorted(list(self._qubits_set))

return qubit_equality, graph_equality

def _json_dict_(self):
graph_payload = ''
if self._nx_graph is not None:
graph_payload = nx.readwrite.json_graph.node_link_data(self._nx_graph)

qubits_payload = ''
if self._qubits_set is not None:
qubits_payload = sorted(list(self._qubits_set))

return {'qubits': qubits_payload, 'nx_graph': graph_payload}

@classmethod
def _from_json_dict_(cls, qubits, nx_graph, **kwargs):
if qubits == '':
qubits = None
graph_obj = None
if nx_graph != '':
graph_obj = nx.readwrite.json_graph.node_link_graph(nx_graph)
return cls(qubits, graph_obj)
28 changes: 28 additions & 0 deletions cirq-core/cirq/devices/device_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# pylint: disable=wrong-or-nonexistent-copyright-notice
import pytest
import networkx as nx
import cirq


Expand Down Expand Up @@ -75,3 +76,30 @@ def test_qid_pair():

with pytest.raises(ValueError, match='A QidPair cannot have identical qids.'):
cirq.SymmetricalQidPair(q0, q0)


def test_metadata():
qubits = cirq.LineQubit.range(4)
graph = nx.star_graph(3)
metadata = cirq.DeviceMetadata(qubits, graph)
assert metadata.qubit_set() == frozenset(qubits)
assert metadata.nx_graph() == graph

metadata = cirq.DeviceMetadata()
assert metadata.qubit_set() is None
assert metadata.nx_graph() is None


def test_metadata_json_load_logic():
qubits = cirq.LineQubit.range(4)
graph = nx.star_graph(3)
metadata = cirq.DeviceMetadata(qubits, graph)
str_rep = cirq.to_json(metadata)
assert metadata == cirq.read_json(json_text=str_rep)

qubits = None
graph = None
metadata = cirq.DeviceMetadata(qubits, graph)
str_rep = cirq.to_json(metadata)
output = cirq.read_json(json_text=str_rep)
assert metadata == output
1 change: 1 addition & 0 deletions cirq-core/cirq/json_resolver_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def _parallel_gate_op(gate, qubits):
'CZPowGate': cirq.CZPowGate,
'DensePauliString': cirq.DensePauliString,
'DepolarizingChannel': cirq.DepolarizingChannel,
'DeviceMetadata': cirq.DeviceMetadata,
'Duration': cirq.Duration,
'FrozenCircuit': cirq.FrozenCircuit,
'FSimGate': cirq.FSimGate,
Expand Down
9 changes: 2 additions & 7 deletions cirq-core/cirq/protocols/json_serialization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from typing import ClassVar, Dict, List, Optional, Tuple, Type
from unittest import mock

import networkx as nx
import numpy as np
import pandas as pd
import pytest
Expand Down Expand Up @@ -726,13 +727,7 @@ def _eval_repr_data_file(path: pathlib.Path, deprecation_deadline: Optional[str]
if deprecation is not None and deprecation.old_name in content:
ctx_managers.append(deprecation.deprecation_assertion)

imports = {
'cirq': cirq,
'pd': pd,
'sympy': sympy,
'np': np,
'datetime': datetime,
}
imports = {'cirq': cirq, 'pd': pd, 'sympy': sympy, 'np': np, 'datetime': datetime, 'nx': nx}

for m in TESTED_MODULES.keys():
try:
Expand Down
54 changes: 54 additions & 0 deletions cirq-core/cirq/protocols/json_test_data/DeviceMetadata.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
{
"cirq_type": "DeviceMetadata",
"qubits": [
{
"cirq_type": "LineQubit",
"x": 0
},
{
"cirq_type": "LineQubit",
"x": 1
},
{
"cirq_type": "LineQubit",
"x": 2
},
{
"cirq_type": "LineQubit",
"x": 3
}
],
"nx_graph": {
"directed": false,
"multigraph": false,
"graph": {},
"nodes": [
{
"id": 0
},
{
"id": 1
},
{
"id": 2
},
{
"id": 3
}
],
"links": [
{
"source": 0,
"target": 1
},
{
"source": 0,
"target": 2
},
{
"source": 0,
"target": 3
}
]
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
cirq.DeviceMetadata(cirq.LineQubit.range(4), nx.star_graph(3))