Skip to content

Commit

Permalink
[runtime] QubitPlacer 2 - RandomDevicePlacer (#4719)
Browse files Browse the repository at this point in the history
As part of the cirqflow runtime, add `RandomDevicePlacer` which uses `cirq.get_placements` to map from a `NamedTopology` graph to a device graph.

This requires shimming over the ability to get a device graph from a `cirq.Device`.

@tanujkhattar @MichaelBroughton feel free to browse while I add tests
  • Loading branch information
mpharrigan authored Jan 19, 2022
1 parent bce3002 commit 7ffc77c
Show file tree
Hide file tree
Showing 9 changed files with 253 additions and 2 deletions.
2 changes: 2 additions & 0 deletions cirq-google/cirq_google/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,9 @@
QuantumRuntimeConfiguration,
execute,
QubitPlacer,
CouldNotPlaceError,
NaiveQubitPlacer,
RandomDevicePlacer,
)

from cirq_google import experimental
Expand Down
1 change: 1 addition & 0 deletions cirq-google/cirq_google/json_resolver_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,5 @@ def _class_resolver_dictionary() -> Dict[str, ObjectFactory]:
# pylint: enable=line-too-long
'cirq.google.QuantumRuntimeConfiguration': cirq_google.QuantumRuntimeConfiguration,
'cirq.google.NaiveQubitPlacer': cirq_google.NaiveQubitPlacer,
'cirq.google.RandomDevicePlacer': cirq_google.RandomDevicePlacer,
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"cirq_type": "cirq.google.RandomDevicePlacer"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
cirq_google.RandomDevicePlacer()
2 changes: 2 additions & 0 deletions cirq-google/cirq_google/json_test_data/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
'THETA_ZETA_GAMMA_FLOQUET_PHASED_FSIM_CHARACTERIZATION',
'QuantumEngineSampler',
'ValidatingSampler',
'CouldNotPlaceError',
# Abstract:
'ExecutableSpec',
],
Expand All @@ -67,6 +68,7 @@
'SharedRuntimeInfo',
'ExecutableGroupResultFilesystemRecord',
'NaiveQubitPlacer',
'RandomDevicePlacer',
]
},
tested_elsewhere=[
Expand Down
2 changes: 2 additions & 0 deletions cirq-google/cirq_google/workflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,7 @@

from cirq_google.workflow.qubit_placement import (
QubitPlacer,
CouldNotPlaceError,
NaiveQubitPlacer,
RandomDevicePlacer,
)
30 changes: 30 additions & 0 deletions cirq-google/cirq_google/workflow/_device_shim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright 2021 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
from typing import Iterable, cast

import cirq
import networkx as nx


def _gridqubits_to_graph_device(qubits: Iterable[cirq.GridQubit]):
return nx.Graph(
pair for pair in itertools.combinations(qubits, 2) if pair[0].is_adjacent(pair[1])
)


def _Device_dot_get_nx_graph(device: 'cirq.Device') -> nx.Graph:
"""Shim over future `cirq.Device` method to get a NetworkX graph."""
return _gridqubits_to_graph_device(cast(Iterable[cirq.GridQubit], device.qubit_set()))
136 changes: 135 additions & 1 deletion cirq-google/cirq_google/workflow/qubit_placement.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,24 @@

import abc
import dataclasses
from typing import Dict, Any, Tuple, TYPE_CHECKING
from functools import lru_cache
from typing import Dict, Any, Tuple, List, Callable, TYPE_CHECKING

import numpy as np

import cirq
from cirq import _compat
from cirq.devices.named_topologies import get_placements
from cirq_google.workflow._device_shim import _Device_dot_get_nx_graph

if TYPE_CHECKING:
import cirq_google as cg


class CouldNotPlaceError(RuntimeError):
"""Raised if a problem topology could not be placed on a device graph."""


class QubitPlacer(metaclass=abc.ABCMeta):
@abc.abstractmethod
def place_circuit(
Expand Down Expand Up @@ -73,3 +80,130 @@ def _json_dict_(self) -> Dict[str, Any]:

def __repr__(self) -> str:
return _compat.dataclass_repr(self, namespace='cirq_google')


def default_topo_node_to_qubit(node: Any) -> cirq.Qid:
"""The default mapping from `cirq.NamedTopology` nodes and `cirq.Qid`.
There is a correspondence between nodes and the "abstract" Qids
used to construct un-placed circuit. `cirq.get_placements` returns a dictionary
mapping from node to Qid. We use this function to transform it into a mapping
from "abstract" Qid to device Qid. This function encodes the default behavior used by
`RandomDevicePlacer`.
If nodes are tuples of integers, map to `cirq.GridQubit`. Otherwise, try
to map to `cirq.LineQubit` and rely on its validation.
Args:
node: A node from a `cirq.NamedTopology` graph.
Returns:
A `cirq.Qid` appropriate for the node type.
"""

try:
return cirq.GridQubit(*node)
except TypeError:
return cirq.LineQubit(node)


@lru_cache()
def _cached_get_placements(
problem_topo: 'cirq.NamedTopology', device: 'cirq.Device'
) -> List[Dict[Any, 'cirq.Qid']]:
"""Cache `cirq.get_placements` onto the specific device."""
return get_placements(
big_graph=_Device_dot_get_nx_graph(device), small_graph=problem_topo.graph
)


def _get_random_placement(
problem_topology: 'cirq.NamedTopology',
device: 'cirq.Device',
rs: np.random.RandomState,
topo_node_to_qubit_func: Callable[[Any], 'cirq.Qid'] = default_topo_node_to_qubit,
) -> Dict['cirq.Qid', 'cirq.Qid']:
"""Place `problem_topology` randomly onto a device.
This is a helper function used by `RandomDevicePlacer.place_circuit`.
"""
placements = _cached_get_placements(problem_topology, device)
if len(placements) == 0:
raise CouldNotPlaceError
random_i = rs.randint(len(placements))
placement = placements[random_i]
placement_gq = {topo_node_to_qubit_func(k): v for k, v in placement.items()}
return placement_gq


class RandomDevicePlacer(QubitPlacer):
def __init__(
self,
topo_node_to_qubit_func: Callable[[Any], cirq.Qid] = default_topo_node_to_qubit,
):
"""A placement strategy that randomly places circuits onto devices.
Args:
topo_node_to_qubit_func: A function that maps from `cirq.NamedTopology` nodes
to `cirq.Qid`. There is a correspondence between nodes and the "abstract" Qids
used to construct the un-placed circuit. `cirq.get_placements` returns a dictionary
mapping from node to Qid. We use this function to transform it into a mapping
from "abstract" Qid to device Qid. By default: nodes which are tuples correspond
to `cirq.GridQubit`s; otherwise `cirq.LineQubit`.
Note:
The attribute `topo_node_to_qubit_func` is not preserved in JSON serialization. This
bit of plumbing does not affect the placement behavior.
"""
self.topo_node_to_qubit_func = topo_node_to_qubit_func

def place_circuit(
self,
circuit: 'cirq.AbstractCircuit',
problem_topology: 'cirq.NamedTopology',
shared_rt_info: 'cg.SharedRuntimeInfo',
rs: np.random.RandomState,
) -> Tuple['cirq.FrozenCircuit', Dict[Any, 'cirq.Qid']]:
"""Place a circuit with a given topology onto a device via `cirq.get_placements` with
randomized selection of the placement each time.
This requires device information to be present in `shared_rt_info`.
Args:
circuit: The circuit.
problem_topology: The topologies (i.e. connectivity) of the circuit.
shared_rt_info: A `cg.SharedRuntimeInfo` object that contains a `device` attribute
of type `cirq.Device` to enable placement.
rs: A `RandomState` as a source of randomness for random placements.
Returns:
A tuple of a new frozen circuit with the qubits placed and a mapping from input
qubits or nodes to output qubits.
Raises:
ValueError: If `shared_rt_info` does not have a device field.
"""
device = shared_rt_info.device
if device is None:
raise ValueError(
"RandomDevicePlacer requires shared_rt_info.device to be a `cirq.Device`. "
"This should have been set during the initialization phase of `cg.execute`."
)
placement = _get_random_placement(
problem_topology, device, rs=rs, topo_node_to_qubit_func=self.topo_node_to_qubit_func
)
return circuit.unfreeze().transform_qubits(placement).freeze(), placement

@classmethod
def _json_namespace_(cls) -> str:
return 'cirq.google'

def _json_dict_(self) -> Dict[str, Any]:
return cirq.obj_to_dict_helper(self, [])

def __repr__(self) -> str:
return "cirq_google.RandomDevicePlacer()"

def __eq__(self, other):
if isinstance(other, RandomDevicePlacer):
return True
78 changes: 77 additions & 1 deletion cirq-google/cirq_google/workflow/qubit_placement_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

import cirq
import cirq_google as cg
Expand All @@ -36,6 +37,81 @@ def test_naive_qubit_placer():
)
assert circuit is not circuit2
assert circuit == circuit2
assert all(q in cg.Sycamore23.qubit_set() for q in circuit.all_qubits())
assert all(q in cg.Sycamore23.qubit_set() for q in circuit2.all_qubits())
for k, v in mapping.items():
assert k == v


def test_random_device_placer_tilted_square_lattice():
topo = cirq.TiltedSquareLattice(4, 2)
qubits = sorted(topo.nodes_to_gridqubits().values())
circuit = cirq.experiments.random_rotations_between_grid_interaction_layers_circuit(
qubits, depth=8, two_qubit_op_factory=lambda a, b, _: cirq.SQRT_ISWAP(a, b)
)
assert not all(q in cg.Sycamore23.qubit_set() for q in circuit.all_qubits())

qp = cg.RandomDevicePlacer()
circuit2, mapping = qp.place_circuit(
circuit,
problem_topology=topo,
shared_rt_info=cg.SharedRuntimeInfo(run_id='1', device=cg.Sycamore23),
rs=np.random.RandomState(1),
)
assert circuit is not circuit2
assert circuit != circuit2
assert all(q in cg.Sycamore23.qubit_set() for q in circuit2.all_qubits())
for k, v in mapping.items():
assert k != v


def test_random_device_placer_line():
topo = cirq.LineTopology(8)
qubits = cirq.LineQubit.range(8)
circuit = cirq.testing.random_circuit(qubits, n_moments=8, op_density=1.0, random_state=52)

qp = cg.RandomDevicePlacer()
circuit2, mapping = qp.place_circuit(
circuit,
problem_topology=topo,
shared_rt_info=cg.SharedRuntimeInfo(run_id='1', device=cg.Sycamore23),
rs=np.random.RandomState(1),
)
assert circuit is not circuit2
assert circuit != circuit2
assert all(q in cg.Sycamore23.qubit_set() for q in circuit2.all_qubits())
for k, v in mapping.items():
assert k != v


def test_random_device_placer_repr():
cirq.testing.assert_equivalent_repr(cg.RandomDevicePlacer(), global_vals={'cirq_google': cg})


def test_random_device_placer_bad_device():
topo = cirq.LineTopology(8)
qubits = cirq.LineQubit.range(8)
circuit = cirq.testing.random_circuit(qubits, n_moments=8, op_density=1.0, random_state=52)
qp = cg.RandomDevicePlacer()
with pytest.raises(ValueError, match=r'.*shared_rt_info\.device.*'):
qp.place_circuit(
circuit,
problem_topology=topo,
shared_rt_info=cg.SharedRuntimeInfo(run_id='1'),
rs=np.random.RandomState(1),
)


def test_random_device_placer_small_device():
topo = cirq.TiltedSquareLattice(3, 3)
qubits = sorted(topo.nodes_to_gridqubits().values())
circuit = cirq.experiments.random_rotations_between_grid_interaction_layers_circuit(
qubits, depth=8, two_qubit_op_factory=lambda a, b, _: cirq.SQRT_ISWAP(a, b)
)
qp = cg.RandomDevicePlacer()
with pytest.raises(cg.CouldNotPlaceError):
qp.place_circuit(
circuit,
problem_topology=topo,
shared_rt_info=cg.SharedRuntimeInfo(run_id='1', device=cg.Foxtail),
rs=np.random.RandomState(1),
)

0 comments on commit 7ffc77c

Please sign in to comment.