Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Register Workflows #269

Merged
merged 29 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
649f418
Added workflow_registry and register_workflow
mtweiden Aug 14, 2024
b4e5366
compile checks workflow_registry
mtweiden Aug 14, 2024
5912fdf
Pre-commit
mtweiden Aug 14, 2024
18929e1
Workflows can be registered through GateSets
mtweiden Aug 16, 2024
bb3febb
GateSets are hashable
mtweiden Aug 16, 2024
3694184
Check for Gate sequences
mtweiden Aug 16, 2024
888e0b2
Tests for register_workflow
mtweiden Aug 16, 2024
f007e6b
pre-commit
mtweiden Aug 16, 2024
0a33d18
Added test for optimization_level=2
mtweiden Aug 16, 2024
a8b5e25
Renamed workflow_registry -> _workflow_registry
mtweiden Aug 20, 2024
ae23a05
Added clear_register
mtweiden Aug 20, 2024
6e59f4b
register -> registry
mtweiden Aug 20, 2024
92364bd
Fixed spelling error
mtweiden Aug 28, 2024
a74a8aa
Added is_workflow static function
mtweiden Aug 28, 2024
6ff5f9a
Workflow checking is done in Workflow construction
mtweiden Aug 28, 2024
3725b22
No default optimization level
mtweiden Aug 28, 2024
6b7da3c
Permutation robust Gateset hash
mtweiden Aug 28, 2024
99dfee1
Gateset hash test
mtweiden Aug 28, 2024
1e02804
Moved documentation
mtweiden Aug 28, 2024
b6219a9
Removed clear_registry function
mtweiden Aug 28, 2024
6e31654
Removed clear_registry function
mtweiden Aug 28, 2024
6e01010
Changed test
mtweiden Aug 28, 2024
46972c1
Fixed import global conflict
mtweiden Aug 28, 2024
433a108
MachineModels registered in _compile_registry
mtweiden Aug 28, 2024
74e53a0
only considers registered MachineModels
mtweiden Aug 28, 2024
cfa229d
Updated tests for _compile_registry
mtweiden Aug 28, 2024
47fa2ed
Removed unused import
mtweiden Aug 28, 2024
bae9fee
pre-commit
mtweiden Aug 28, 2024
1f94658
Fixed imports
mtweiden Aug 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions bqskit/compiler/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from bqskit.compiler.compiler import Compiler
from bqskit.compiler.machine import MachineModel
from bqskit.compiler.passdata import PassData
from bqskit.compiler.registry import _compile_registry
from bqskit.compiler.workflow import Workflow
from bqskit.compiler.workflow import WorkflowLike
from bqskit.ir.circuit import Circuit
Expand Down Expand Up @@ -669,6 +670,12 @@ def build_workflow(
if model is None:
model = MachineModel(input.num_qudits, radixes=input.radixes)

# Use a registered workflow if model is found in the registry for a given
# optimization_level
if model in _compile_registry:
if optimization_level in _compile_registry[model]:
return _compile_registry[model][optimization_level]

if isinstance(input, Circuit):
if input.num_qudits > max_synthesis_size:
if any(
Expand Down
4 changes: 4 additions & 0 deletions bqskit/compiler/gateset.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,5 +231,9 @@ def __repr__(self) -> str:
"""Detailed representation of the GateSet."""
return self._gates.__repr__().replace('frozenset', 'GateSet')

def __hash__(self) -> int:
"""Hash of the GateSet."""
return hash(tuple(sorted([g.name for g in self._gates])))


GateSetLike = Union[GateSet, Iterable[Gate], Gate]
64 changes: 64 additions & 0 deletions bqskit/compiler/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Register MachineModel specific default workflows."""
from __future__ import annotations

import warnings

from bqskit.compiler.machine import MachineModel
from bqskit.compiler.workflow import Workflow
from bqskit.compiler.workflow import WorkflowLike


_compile_registry: dict[MachineModel, dict[int, Workflow]] = {}


def register_workflow(
key: MachineModel,
workflow: WorkflowLike,
optimization_level: int,
) -> None:
"""
Register a workflow for a given MachineModel.

The _compile_registry enables MachineModel specific workflows to be
registered for use in the `bqskit.compile` method. _compile_registry maps
MachineModels a dictionary of Workflows which are indexed by optimization
level. This object should not be accessed directly by the user, but
instead through the `register_workflow` function.

Args:
key (MachineModel): A MachineModel to register the workflow under.
If a circuit is compiled targeting this machine or gate set, the
registered workflow will be used.

workflow (list[BasePass]): The workflow or list of passes that will
be executed if the MachineModel in a call to `compile` matches
`key`. If `key` is already registered, a warning will be logged.

optimization_level ptional[int): The optimization level with which
to register the workflow. If no level is provided, the Workflow
will be registered as level 1.

Example:
model_t = SpecificMachineModel(num_qudits, radixes)
workflow = [QuickPartitioner(3), NewFangledOptimization()]
register_workflow(model_t, workflow, level)
...
new_circuit = compile(circuit, model_t, optimization_level=level)

Raises:
Warning: If a workflow for a given optimization_level is overwritten.
"""
workflow = Workflow(workflow)

global _compile_registry
new_workflow = {optimization_level: workflow}
if key in _compile_registry:
if optimization_level in _compile_registry[key]:
m = f'Overwritting workflow for {key} at level '
m += f'{optimization_level}. If multiple Namespace packages are '
m += 'installed, ensure that their __init__.py files do not '
m += 'attempt to overwrite the same default Workflows.'
warnings.warn(m)
_compile_registry[key].update(new_workflow)
else:
_compile_registry[key] = new_workflow
6 changes: 6 additions & 0 deletions bqskit/compiler/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ def name(self) -> str:
"""The name of the pass."""
return self._name or self.__class__.__name__

@staticmethod
def is_workflow(workflow: WorkflowLike) -> bool:
if not is_iterable(workflow):
return isinstance(workflow, BasePass)
return all(isinstance(p, BasePass) for p in workflow)

def __str__(self) -> str:
name_seq = f'Workflow: {self.name}\n\t'
pass_strs = [
Expand Down
13 changes: 13 additions & 0 deletions tests/compiler/test_gateset.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,3 +522,16 @@ def test_gate_set_repr() -> None:
repr(gate_set) == 'GateSet({CNOTGate, U3Gate})'
or repr(gate_set) == 'GateSet({U3Gate, CNOTGate})'
)


def test_gate_set_hash() -> None:
gate_set_1 = GateSet({CNOTGate(), U3Gate()})
gate_set_2 = GateSet({U3Gate(), CNOTGate()})
gate_set_3 = GateSet({U3Gate(), CNOTGate(), RZGate()})

h1 = hash(gate_set_1)
h2 = hash(gate_set_2)
h3 = hash(gate_set_3)

assert h1 == h2
assert h1 != h3
119 changes: 119 additions & 0 deletions tests/compiler/test_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""This file tests the register_workflow function."""
from __future__ import annotations

from itertools import combinations
from random import choice

import pytest
from numpy import allclose

from bqskit.compiler.compile import compile
from bqskit.compiler.machine import MachineModel
from bqskit.compiler.registry import _compile_registry
from bqskit.compiler.registry import register_workflow
from bqskit.compiler.workflow import Workflow
from bqskit.compiler.workflow import WorkflowLike
from bqskit.ir import Circuit
from bqskit.ir import Gate
from bqskit.ir.gates import CZGate
from bqskit.ir.gates import HGate
from bqskit.ir.gates import RZGate
from bqskit.ir.gates import U3Gate
from bqskit.passes import QSearchSynthesisPass
from bqskit.passes import QuickPartitioner
from bqskit.passes import ScanningGateRemovalPass


def machine_match(mach_a: MachineModel, mach_b: MachineModel) -> bool:
if mach_a.num_qudits != mach_b.num_qudits:
return False
if mach_a.radixes != mach_b.radixes:
return False
if mach_a.coupling_graph != mach_b.coupling_graph:
return False
if mach_a.gate_set != mach_b.gate_set:
return False
return True


def unitary_match(unit_a: Circuit, unit_b: Circuit) -> bool:
return allclose(unit_a.get_unitary(), unit_b.get_unitary(), atol=1e-5)


def workflow_match(
workflow_a: WorkflowLike,
workflow_b: WorkflowLike,
) -> bool:
if not isinstance(workflow_a, Workflow):
workflow_a = Workflow(workflow_a)
if not isinstance(workflow_b, Workflow):
workflow_b = Workflow(workflow_b)
if len(workflow_a) != len(workflow_b):
return False
for a, b in zip(workflow_a, workflow_b):
if a.name != b.name:
return False
return True


def simple_circuit(num_qudits: int, gate_set: list[Gate]) -> Circuit:
circ = Circuit(num_qudits)
gate = choice(gate_set)
if gate.num_qudits == 1:
loc = choice(range(num_qudits))
else:
loc = choice(list(combinations(range(num_qudits), 2))) # type: ignore
gate_inv = gate.get_inverse()
circ.append_gate(gate, loc)
circ.append_gate(gate_inv, loc)
return circ


class TestRegisterWorkflow:

@pytest.fixture(autouse=True)
def setup(self) -> None:
# global _compile_registry
_compile_registry.clear()

def test_register_workflow(self) -> None:
global _compile_registry
assert _compile_registry == {}
gateset = [CZGate(), HGate(), RZGate()]
num_qudits = 3
machine = MachineModel(num_qudits, gate_set=gateset)
workflow = [QuickPartitioner(), ScanningGateRemovalPass()]
register_workflow(machine, workflow, 1)
assert machine in _compile_registry
assert 1 in _compile_registry[machine]
assert workflow_match(_compile_registry[machine][1], workflow)

def test_custom_compile_machine(self) -> None:
global _compile_registry
assert _compile_registry == {}
gateset = [CZGate(), HGate(), RZGate()]
num_qudits = 3
machine = MachineModel(num_qudits, gate_set=gateset)
workflow = [QuickPartitioner(2)]
register_workflow(machine, workflow, 1)
circuit = simple_circuit(num_qudits, gateset)
result = compile(circuit, machine)
assert unitary_match(result, circuit)
assert result.num_operations > 0
assert result.gate_counts != circuit.gate_counts
result.unfold_all()
assert result.gate_counts == circuit.gate_counts

def test_custom_opt_level(self) -> None:
global _compile_registry
assert _compile_registry == {}
gateset = [CZGate(), HGate(), RZGate()]
num_qudits = 3
machine = MachineModel(num_qudits, gate_set=gateset)
workflow = [QSearchSynthesisPass()]
register_workflow(machine, workflow, 2)
circuit = simple_circuit(num_qudits, gateset)
result = compile(circuit, machine, optimization_level=2)
assert unitary_match(result, circuit)
assert result.gate_counts != circuit.gate_counts
assert U3Gate() in result.gate_set
Loading