Skip to content

Commit

Permalink
[cirqflow] Hardcoded qubit placement (quantumlib#5194)
Browse files Browse the repository at this point in the history
Add a `QubitPlacer` that takes an explicit mapping in its constructor.
  • Loading branch information
mpharrigan authored Apr 22, 2022
1 parent aa36253 commit 5ea1547
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 3 deletions.
1 change: 1 addition & 0 deletions cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
LineTopology,
TiltedSquareLattice,
get_placements,
is_valid_placement,
draw_placements,
)

Expand Down
1 change: 1 addition & 0 deletions cirq/devices/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
LineTopology,
TiltedSquareLattice,
get_placements,
is_valid_placement,
draw_placements,
)

Expand Down
59 changes: 57 additions & 2 deletions cirq/devices/named_topologies.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,18 @@
import abc
import warnings
from dataclasses import dataclass
from typing import Dict, List, Tuple, Any, Sequence, Union, Iterable, TYPE_CHECKING
from typing import (
Dict,
List,
Tuple,
Any,
Sequence,
Union,
Iterable,
TYPE_CHECKING,
Callable,
Optional,
)

import networkx as nx
from matplotlib import pyplot as plt
Expand Down Expand Up @@ -290,13 +301,45 @@ def get_placements(
return small_to_bigs


def _is_valid_placement_helper(
big_graph: nx.Graph, small_mapped: nx.Graph, small_to_big_mapping: Dict
):
"""Helper function for `is_valid_placement` that assumes the mapping of `small_graph` has
already occurred.
This is so we don't duplicate work when checking placements during `draw_placements`.
"""
subgraph = big_graph.subgraph(small_to_big_mapping.values())
return (subgraph.nodes == small_mapped.nodes) and (subgraph.edges == small_mapped.edges)


def is_valid_placement(big_graph: nx.Graph, small_graph: nx.Graph, small_to_big_mapping: Dict):
"""Return whether the given placement is a valid placement of small_graph onto big_graph.
This is done by making sure all the nodes and edges on the mapped version of `small_graph`
are present in `big_graph`.
Args:
big_graph: A larger graph we're placing `small_graph` onto.
small_graph: A smaller, (potential) sub-graph to validate the given mapping.
small_to_big_mapping: A mapping from `small_graph` nodes to `big_graph`
nodes. After the mapping occurs, we check whether all of the mapped nodes and
edges exist on `big_graph`.
"""
small_mapped = nx.relabel_nodes(small_graph, small_to_big_mapping)
return _is_valid_placement_helper(
big_graph=big_graph, small_mapped=small_mapped, small_to_big_mapping=small_to_big_mapping
)


def draw_placements(
big_graph: nx.Graph,
small_graph: nx.Graph,
small_to_big_mappings: Sequence[Dict],
max_plots: int = 20,
axes: Sequence[plt.Axes] = None,
tilted=True,
tilted: bool = True,
bad_placement_callback: Optional[Callable[[plt.Axes, int], None]] = None,
):
"""Draw a visualization of placements from small_graph onto big_graph using Matplotlib.
Expand All @@ -312,6 +355,9 @@ def draw_placements(
`max_plots` plots.
axes: Optional list of matplotlib Axes to contain the drawings.
tilted: Whether to draw gridlike graphs in the ordinary cartesian or tilted plane.
bad_placement_callback: If provided, we check that the given mappings are valid. If not,
this callback is called. The callback should accept `ax` and `i` keyword arguments
for the current axis and mapping index, respectively.
"""
if len(small_to_big_mappings) > max_plots:
# coverage: ignore
Expand All @@ -331,6 +377,15 @@ def draw_placements(
ax = plt.gca()

small_mapped = nx.relabel_nodes(small_graph, small_to_big_map)
if bad_placement_callback is not None:
# coverage: ignore
if not _is_valid_placement_helper(
big_graph=big_graph,
small_mapped=small_mapped,
small_to_big_mapping=small_to_big_map,
):
bad_placement_callback(ax, i)

draw_gridlike(big_graph, ax=ax, tilted=tilted)
draw_gridlike(
small_mapped,
Expand Down
20 changes: 19 additions & 1 deletion cirq/devices/named_topologies_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,14 @@
import cirq
import networkx as nx
import pytest
from cirq import draw_gridlike, LineTopology, TiltedSquareLattice, get_placements, draw_placements
from cirq import (
draw_gridlike,
LineTopology,
TiltedSquareLattice,
get_placements,
draw_placements,
is_valid_placement,
)


@pytest.mark.parametrize('width, height', list(itertools.product([1, 2, 3, 24], repeat=2)))
Expand Down Expand Up @@ -119,3 +126,14 @@ def test_get_placements():
draw_placements(syc23, topo.graph, placements[::3], axes=axes)
for ax in axes:
ax.scatter.assert_called()


def test_is_valid_placement():
topo = TiltedSquareLattice(4, 2)
syc23 = TiltedSquareLattice(8, 4).graph
placements = get_placements(syc23, topo.graph)
for placement in placements:
assert is_valid_placement(syc23, topo.graph, placement)

bad_placement = topo.nodes_to_gridqubits(offset=(100, 100))
assert not is_valid_placement(syc23, topo.graph, bad_placement)

0 comments on commit 5ea1547

Please sign in to comment.