Skip to content

Commit

Permalink
Deprecate device.qubit_set in cirq_google. (quantumlib#4940)
Browse files Browse the repository at this point in the history
Yet more of quantumlib#4744 .

It also looks like now we can get rid of the device shim @mpharrigan , @dstrain115 (will leave to seperate PR).
  • Loading branch information
MichaelBroughton authored and rht committed May 1, 2023
1 parent 0ae5741 commit 683551c
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 10 deletions.
4 changes: 4 additions & 0 deletions cirq-google/cirq_google/devices/serializable_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ def metadata(self) -> cirq.GridDeviceMetadata:
"""Get metadata information for device."""
return self._metadata

@_compat.deprecated(
fix='Please use metadata.qubit_set if applicable.',
deadline='v0.15',
)
def qubit_set(self) -> FrozenSet[cirq.Qid]:
return frozenset(self.qubits)

Expand Down
10 changes: 9 additions & 1 deletion cirq-google/cirq_google/devices/serializable_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,14 @@ def test_gate_definition_equality():
eq.add_equality_group(cirq.X)


def test_qubit_set_deprecated():
foxtail = cg.SerializableDevice.from_proto(
proto=cg.devices.known_devices.FOXTAIL_PROTO, gate_sets=[cg.XMON]
)
with cirq.testing.assert_deprecated('qubit_set', deadline='v0.15'):
_ = foxtail.qubit_set()


def test_foxtail():
valid_qubit1 = cirq.GridQubit(0, 0)
valid_qubit2 = cirq.GridQubit(1, 0)
Expand All @@ -143,7 +151,7 @@ def test_foxtail():
foxtail = cg.SerializableDevice.from_proto(
proto=cg.devices.known_devices.FOXTAIL_PROTO, gate_sets=[cg.XMON]
)
assert foxtail.qubit_set() == frozenset(cirq.GridQubit.rect(2, 11, 0, 0))
assert foxtail.metadata.qubit_set == frozenset(cirq.GridQubit.rect(2, 11, 0, 0))
foxtail.validate_operation(cirq.X(valid_qubit1))
foxtail.validate_operation(cirq.X(valid_qubit2))
foxtail.validate_operation(cirq.X(valid_qubit3))
Expand Down
4 changes: 4 additions & 0 deletions cirq-google/cirq_google/devices/xmon_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ def metadata(self) -> cirq.GridDeviceMetadata:
"""Return the metadata for this device"""
return self._metadata

@_compat.deprecated(
fix='Use metadata.qubit_set if applicable.',
deadline='v0.15',
)
def qubit_set(self) -> FrozenSet[cirq.GridQubit]:
return self.qubits

Expand Down
7 changes: 7 additions & 0 deletions cirq-google/cirq_google/devices/xmon_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,13 @@ def test_device_metadata():
)


@mock.patch.dict(os.environ, clear='CIRQ_TESTING')
def test_qubit_set_deprecated():
d = square_device(2, 2)
with cirq.testing.assert_deprecated('qubit_set', deadline='v0.15'):
_ = d.qubit_set()


@mock.patch.dict(os.environ, clear='CIRQ_TESTING')
def test_init():
d = square_device(2, 2, holes=[cirq.GridQubit(1, 1)])
Expand Down
4 changes: 2 additions & 2 deletions cirq-google/cirq_google/engine/virtual_engine_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@

def _create_perfect_calibration(device: cirq.Device) -> calibration.Calibration:
all_metrics: calibration.ALL_METRICS = {}
qubit_set = device.qubit_set()
if qubit_set is None:
if device.metadata is None:
raise ValueError('Devices for noiseless Virtual Engine must have qubits')
qubit_set = device.metadata.qubit_set
qubits = [cast(cirq.GridQubit, q) for q in qubit_set]
for name in METRICS_1Q:
all_metrics[name] = {(q,): [0.0] for q in qubits}
Expand Down
6 changes: 4 additions & 2 deletions cirq-google/cirq_google/workflow/_device_shim.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import itertools
from typing import Iterable, cast
from typing import Iterable

import cirq
import networkx as nx
Expand All @@ -27,4 +27,6 @@ def _gridqubits_to_graph_device(qubits: Iterable[cirq.GridQubit]):

def _Device_dot_get_nx_graph(device: 'cirq.Device') -> nx.Graph:
"""Shim over future `cirq.Device` method to get a NetworkX graph."""
return _gridqubits_to_graph_device(cast(Iterable[cirq.GridQubit], device.qubit_set()))
if device.metadata is not None:
return device.metadata.nx_graph
raise ValueError('Supplied device must contain metadata.')
35 changes: 30 additions & 5 deletions cirq-google/cirq_google/workflow/qubit_placement_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@
class TestDevice(cirq.Device):
def __init__(self):
self.qubits = cirq.GridQubit.rect(2, 8)
neighbors = [(a, b) for a in self.qubits for b in self.qubits if a.is_adjacent(b)]
self._metadata = cirq.GridDeviceMetadata(neighbors, cirq.Gateset(cirq.H))

@property
def metadata(self):
return self._metadata


def test_naive_qubit_placer():
Expand All @@ -31,7 +37,7 @@ def test_naive_qubit_placer():
qubits, depth=8, two_qubit_op_factory=lambda a, b, _: cirq.SQRT_ISWAP(a, b)
)

assert all(q in cg.Sycamore23.qubit_set() for q in circuit.all_qubits())
assert all(q in cg.Sycamore23.metadata.qubit_set for q in circuit.all_qubits())

qp = cg.NaiveQubitPlacer()
circuit2, mapping = qp.place_circuit(
Expand All @@ -42,7 +48,7 @@ def test_naive_qubit_placer():
)
assert circuit is not circuit2
assert circuit == circuit2
assert all(q in cg.Sycamore23.qubit_set() for q in circuit2.all_qubits())
assert all(q in cg.Sycamore23.metadata.qubit_set for q in circuit2.all_qubits())
for k, v in mapping.items():
assert k == v

Expand All @@ -53,7 +59,7 @@ def test_random_device_placer_tilted_square_lattice():
circuit = cirq.experiments.random_rotations_between_grid_interaction_layers_circuit(
qubits, depth=8, two_qubit_op_factory=lambda a, b, _: cirq.SQRT_ISWAP(a, b)
)
assert not all(q in cg.Sycamore23.qubit_set() for q in circuit.all_qubits())
assert not all(q in cg.Sycamore23.metadata.qubit_set for q in circuit.all_qubits())

qp = cg.RandomDevicePlacer()
circuit2, mapping = qp.place_circuit(
Expand All @@ -64,7 +70,7 @@ def test_random_device_placer_tilted_square_lattice():
)
assert circuit is not circuit2
assert circuit != circuit2
assert all(q in cg.Sycamore23.qubit_set() for q in circuit2.all_qubits())
assert all(q in cg.Sycamore23.metadata.qubit_set for q in circuit2.all_qubits())
for k, v in mapping.items():
assert k != v

Expand All @@ -83,7 +89,7 @@ def test_random_device_placer_line():
)
assert circuit is not circuit2
assert circuit != circuit2
assert all(q in cg.Sycamore23.qubit_set() for q in circuit2.all_qubits())
assert all(q in cg.Sycamore23.metadata.qubit_set for q in circuit2.all_qubits())
for k, v in mapping.items():
assert k != v

Expand Down Expand Up @@ -120,3 +126,22 @@ def test_random_device_placer_small_device():
shared_rt_info=cg.SharedRuntimeInfo(run_id='1', device=TestDevice()),
rs=np.random.RandomState(1),
)


def test_device_missing_metadata():
class BadDevice(cirq.Device):
pass

topo = cirq.TiltedSquareLattice(3, 3)
qubits = sorted(topo.nodes_to_gridqubits().values())
circuit = cirq.experiments.random_rotations_between_grid_interaction_layers_circuit(
qubits, depth=8, two_qubit_op_factory=lambda a, b, _: cirq.SQRT_ISWAP(a, b)
)
qp = cg.RandomDevicePlacer()
with pytest.raises(ValueError):
qp.place_circuit(
circuit,
problem_topology=topo,
shared_rt_info=cg.SharedRuntimeInfo(run_id='1', device=BadDevice()),
rs=np.random.RandomState(1),
)

0 comments on commit 683551c

Please sign in to comment.