Skip to content

Commit

Permalink
Merge pull request #250 from SoshunNaito/feature_subgraph_isomorphism
Browse files Browse the repository at this point in the history
add StaticPlacementPass
  • Loading branch information
edyounis authored Jun 17, 2024
2 parents 6974367 + f8584e0 commit f5abe8e
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 1 deletion.
2 changes: 2 additions & 0 deletions bqskit/passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@
from bqskit.passes.mapping.layout.pam import PAMLayoutPass
from bqskit.passes.mapping.layout.sabre import GeneralizedSabreLayoutPass
from bqskit.passes.mapping.placement.greedy import GreedyPlacementPass
from bqskit.passes.mapping.placement.static import StaticPlacementPass
from bqskit.passes.mapping.placement.trivial import TrivialPlacementPass
from bqskit.passes.mapping.routing.pam import PAMRoutingPass
from bqskit.passes.mapping.routing.sabre import GeneralizedSabreRoutingPass
Expand Down Expand Up @@ -364,6 +365,7 @@
'GeneralizedSabreLayoutPass',
'GreedyPlacementPass',
'TrivialPlacementPass',
'StaticPlacementPass',
'GeneralizedSabreRoutingPass',
'SetModelPass',
'U3Decomposition',
Expand Down
2 changes: 2 additions & 0 deletions bqskit/passes/mapping/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from bqskit.passes.mapping.layout.pam import PAMLayoutPass
from bqskit.passes.mapping.layout.sabre import GeneralizedSabreLayoutPass
from bqskit.passes.mapping.placement.greedy import GreedyPlacementPass
from bqskit.passes.mapping.placement.static import StaticPlacementPass
from bqskit.passes.mapping.placement.trivial import TrivialPlacementPass
from bqskit.passes.mapping.routing.pam import PAMRoutingPass
from bqskit.passes.mapping.routing.sabre import GeneralizedSabreRoutingPass
Expand All @@ -22,6 +23,7 @@
'GeneralizedSabreLayoutPass',
'GreedyPlacementPass',
'TrivialPlacementPass',
'StaticPlacementPass',
'GeneralizedSabreRoutingPass',
'SetModelPass',
'ApplyPlacement',
Expand Down
3 changes: 2 additions & 1 deletion bqskit/passes/mapping/placement/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from __future__ import annotations

from bqskit.passes.mapping.placement.greedy import GreedyPlacementPass
from bqskit.passes.mapping.placement.static import StaticPlacementPass
from bqskit.passes.mapping.placement.trivial import TrivialPlacementPass

__all__ = ['GreedyPlacementPass', 'TrivialPlacementPass']
__all__ = ['GreedyPlacementPass', 'TrivialPlacementPass', 'StaticPlacementPass']
135 changes: 135 additions & 0 deletions bqskit/passes/mapping/placement/static.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
"""This module implements the StaticPlacementPass class."""
from __future__ import annotations

import logging
import time

from bqskit.compiler.basepass import BasePass
from bqskit.compiler.passdata import PassData
from bqskit.ir.circuit import Circuit
from bqskit.qis.graph import CouplingGraph

_logger = logging.getLogger(__name__)


class StaticPlacementPass(BasePass):
"""Find a subgraph monomorphic to the coupling graph so that no SWAPs are
needed."""

def __init__(self, timeout_sec: float = 10) -> None:
self.timeout_sec = timeout_sec

def _find_monomorphic_subgraph(
self,
time_limit: float,
physical_graph: CouplingGraph,
num_logical_qudits: int,
minimal_degrees: list[int],
connected_indices: list[list[int]],
current_placement: list[int] = [],
current_index: int = 0,
) -> list[int]:
"""Recursively find a monomorphic subgraph."""
if current_index == num_logical_qudits:
return current_placement

if time.time() > time_limit:
return []

# Find all possible placements for the current logical qudit
candidate_indices = set()

# Filter out occupied qudits and qudits with insufficient degrees
physical_degrees = physical_graph.get_qudit_degrees()
for x in range(physical_graph.num_qudits):
if (
physical_degrees[x] >= minimal_degrees[current_index]
and x not in current_placement
):
candidate_indices.add(x)

# Filter out qudits that are not connected to previous logical qudits
for i in connected_indices[current_index]:
candidate_indices &= set(
physical_graph.get_neighbors_of(current_placement[i]),
)

# Try all possible placements for the current logical qudit
for x in candidate_indices:
new_placement = current_placement + [x]
result = self._find_monomorphic_subgraph(
time_limit,
physical_graph,
num_logical_qudits,
minimal_degrees,
connected_indices,
new_placement,
current_index + 1,
)
if len(result) == num_logical_qudits:
return result

# If no valid placement is found, return an empty list
return []

def find_monomorphic_subgraph(
self,
physical_graph: CouplingGraph,
logical_graph: CouplingGraph,
) -> list[int]:
"""Try all possible placements."""

# To be optimized later
logical_qubit_order = list(range(logical_graph.num_qudits))

minimum_degrees = [
logical_graph.get_qudit_degrees()[i] for i in logical_qubit_order
]
connected_indices: list[list[int]] = [
[] for _ in range(logical_graph.num_qudits)
]
for i in range(logical_graph.num_qudits):
for j in range(i):
if logical_qubit_order[j] in logical_graph.get_neighbors_of(
logical_qubit_order[i],
):
connected_indices[i].append(j)

# Find a monomorphic subgraph
start_time = time.time()
index_to_physical = self._find_monomorphic_subgraph(
start_time + self.timeout_sec,
physical_graph,
logical_graph.num_qudits,
minimum_degrees,
connected_indices,
)
_logger.info(f'elapsed time: {time.time() - start_time}')
if len(index_to_physical) == 0:
return []

# Convert the result to a placement
placement = [-1] * logical_graph.num_qudits
for i, x in enumerate(logical_qubit_order):
placement[x] = index_to_physical[i]
return placement

async def run(self, circuit: Circuit, data: PassData) -> None:
"""Perform the pass's operation, see :class:`BasePass` for more."""
physical_graph = data.model.coupling_graph
logical_graph = circuit.coupling_graph

# Find an monomorphic subgraph
placement = self.find_monomorphic_subgraph(
physical_graph, logical_graph,
)

# Set the placement if it is valid
if len(placement) == logical_graph.num_qudits and all(
placement[e[1]] in physical_graph.get_neighbors_of(placement[e[0]])
for e in logical_graph
):
data.placement = placement
_logger.info(f'Placed qudits on {data.placement}')
else:
_logger.info('No valid placement found')
47 changes: 47 additions & 0 deletions tests/passes/mapping/test_static.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from __future__ import annotations

import pytest

from bqskit.compiler import Compiler
from bqskit.compiler import MachineModel
from bqskit.ir.circuit import Circuit
from bqskit.ir.gates import CNOTGate
from bqskit.passes import ApplyPlacement
from bqskit.passes import GreedyPlacementPass
from bqskit.passes import IfThenElsePass
from bqskit.passes import LogPass
from bqskit.passes import SetModelPass
from bqskit.passes import StaticPlacementPass
from bqskit.passes.control.predicates import PhysicalPredicate
from bqskit.qis import CouplingGraph


def circular_circuit(n: int) -> Circuit:
circuit = Circuit(n)
for i in range(n):
circuit.append_gate(CNOTGate(), [i, (i + 1) % n])
return circuit


@pytest.mark.parametrize(
['grid_size', 'logical_qudits'],
sum([[(n, i) for i in range(2, n**2, 2)] for n in range(2, 8)], [])
+ sum([[(n, i) for i in range(3, n**2, 2)] for n in range(2, 6)], []),
)
def test_circular_to_grid(
grid_size: int, logical_qudits: int, compiler: Compiler,
) -> None:
circuit = circular_circuit(logical_qudits)
cg = CouplingGraph.grid(grid_size, grid_size)
model = MachineModel(grid_size**2, cg)
workflow = [
SetModelPass(model),
StaticPlacementPass(timeout_sec=1.0),
IfThenElsePass(
PhysicalPredicate(),
[LogPass('Static Placement Found')],
[LogPass('Greedy Placement Required'), GreedyPlacementPass()],
),
ApplyPlacement(),
]
compiler.compile(circuit, workflow)

0 comments on commit f5abe8e

Please sign in to comment.