Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Register Workflows #269

Merged
merged 29 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
649f418
Added workflow_registry and register_workflow
mtweiden Aug 14, 2024
b4e5366
compile checks workflow_registry
mtweiden Aug 14, 2024
5912fdf
Pre-commit
mtweiden Aug 14, 2024
18929e1
Workflows can be registered through GateSets
mtweiden Aug 16, 2024
bb3febb
GateSets are hashable
mtweiden Aug 16, 2024
3694184
Check for Gate sequences
mtweiden Aug 16, 2024
888e0b2
Tests for register_workflow
mtweiden Aug 16, 2024
f007e6b
pre-commit
mtweiden Aug 16, 2024
0a33d18
Added test for optimization_level=2
mtweiden Aug 16, 2024
a8b5e25
Renamed workflow_registry -> _workflow_registry
mtweiden Aug 20, 2024
ae23a05
Added clear_register
mtweiden Aug 20, 2024
6e59f4b
register -> registry
mtweiden Aug 20, 2024
92364bd
Fixed spelling error
mtweiden Aug 28, 2024
a74a8aa
Added is_workflow static function
mtweiden Aug 28, 2024
6ff5f9a
Workflow checking is done in Workflow construction
mtweiden Aug 28, 2024
3725b22
No default optimization level
mtweiden Aug 28, 2024
6b7da3c
Permutation robust Gateset hash
mtweiden Aug 28, 2024
99dfee1
Gateset hash test
mtweiden Aug 28, 2024
1e02804
Moved documentation
mtweiden Aug 28, 2024
b6219a9
Removed clear_registry function
mtweiden Aug 28, 2024
6e31654
Removed clear_registry function
mtweiden Aug 28, 2024
6e01010
Changed test
mtweiden Aug 28, 2024
46972c1
Fixed import global conflict
mtweiden Aug 28, 2024
433a108
MachineModels registered in _compile_registry
mtweiden Aug 28, 2024
74e53a0
only considers registered MachineModels
mtweiden Aug 28, 2024
cfa229d
Updated tests for _compile_registry
mtweiden Aug 28, 2024
47fa2ed
Removed unused import
mtweiden Aug 28, 2024
bae9fee
pre-commit
mtweiden Aug 28, 2024
1f94658
Fixed imports
mtweiden Aug 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions bqskit/compiler/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
import numpy as np

from bqskit.compiler.compiler import Compiler
from bqskit.compiler.gateset import GateSet
from bqskit.compiler.machine import MachineModel
from bqskit.compiler.passdata import PassData
from bqskit.compiler.registry import _workflow_registry
from bqskit.compiler.workflow import Workflow
from bqskit.compiler.workflow import WorkflowLike
from bqskit.ir.circuit import Circuit
Expand Down Expand Up @@ -669,6 +671,18 @@ def build_workflow(
if model is None:
model = MachineModel(input.num_qudits, radixes=input.radixes)

# Use a registered workflow if model is found in the registry for a given
# optimization_level
for machine_or_gateset in _workflow_registry:
if isinstance(machine_or_gateset, GateSet):
gate_set = machine_or_gateset
else:
gate_set = machine_or_gateset.gate_set
gs_match = gate_set == model.gate_set
ol_found = optimization_level in _workflow_registry[machine_or_gateset]
if gs_match and ol_found:
return _workflow_registry[machine_or_gateset][optimization_level]
mtweiden marked this conversation as resolved.
Show resolved Hide resolved

if isinstance(input, Circuit):
if input.num_qudits > max_synthesis_size:
if any(
Expand Down
4 changes: 4 additions & 0 deletions bqskit/compiler/gateset.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,5 +231,9 @@ def __repr__(self) -> str:
"""Detailed representation of the GateSet."""
return self._gates.__repr__().replace('frozenset', 'GateSet')

def __hash__(self) -> int:
"""Hash of the GateSet."""
return self.__repr__().__hash__()
mtweiden marked this conversation as resolved.
Show resolved Hide resolved


GateSetLike = Union[GateSet, Iterable[Gate], Gate]
99 changes: 99 additions & 0 deletions bqskit/compiler/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""
The _workflow_registry enables MachineModel or GateSet specific workflows to be
registered for used in the `bqskit.compile` method.

The _workflow_registry maps MachineModels a dictionary of Workflows which
are indexed by optimization level. This object should not be accessed directly
by the user, but instead through the `register_workflow` function.

Example:
model_t = SpecificMachineModel(num_qudits, radixes)
workflow = [QuickPartitioner(3), NewFangledOptimization()]
register_workflow(model_t, workflow, level)
...
new_circuit = compile(circuit, model_t, optimization_level=level)
"""
mtweiden marked this conversation as resolved.
Show resolved Hide resolved
from __future__ import annotations

import logging

from bqskit.compiler.basepass import BasePass
from bqskit.compiler.gateset import GateSet
from bqskit.compiler.gateset import GateSetLike
from bqskit.compiler.machine import MachineModel
from bqskit.compiler.workflow import Workflow
from bqskit.compiler.workflow import WorkflowLike
from bqskit.ir.gate import Gate

_logger = logging.getLogger(__name__)


_workflow_registry: dict[MachineModel | GateSet, dict[int, Workflow]] = {}
mtweiden marked this conversation as resolved.
Show resolved Hide resolved


def register_workflow(
machine_or_gateset: MachineModel | GateSetLike,
workflow: WorkflowLike,
optimization_level: int = 1,
) -> None:
"""
Register a workflow for a given machine model.

Args:
machine_or_gateset (MachineModel | GateSetLike): A MachineModel or
GateSetLike to register the workflow under. If a circuit is
compiled targeting this machine or gate set, the registered
workflow will be used.

workflow (list[BasePass]): The workflow or list of passes that whill
mtweiden marked this conversation as resolved.
Show resolved Hide resolved
be executed if the MachineModel in a call to `compile` matches
`machine`. If `machine` is already registered, a warning will be
logged.

optimization_level (Optional[int]): The optimization level with
which to register the workflow. If no level is provided, the
Workflow will be registered as level 1. (Default: 1)
mtweiden marked this conversation as resolved.
Show resolved Hide resolved

Raises:
TypeError: If `machine_or_gateset` is not a MachineModel or GateSet.

TypeError: If `workflow` is not a list of BasePass objects.
mtweiden marked this conversation as resolved.
Show resolved Hide resolved
"""
if not isinstance(machine_or_gateset, MachineModel):
if isinstance(machine_or_gateset, Gate):
machine_or_gateset = [machine_or_gateset]
if all(isinstance(g, Gate) for g in machine_or_gateset):
machine_or_gateset = GateSet(machine_or_gateset)
else:
m = '`machine_or_gateset` must be a MachineModel or '
m += f'GateSet, got {type(machine_or_gateset)}.'
raise TypeError(m)

workflow = Workflow(workflow)

for p in workflow:
if not isinstance(p, BasePass):
m = 'All elements of `workflow` must be BasePass objects. Got '
m += f'{type(p)}.'
raise TypeError(m)
mtweiden marked this conversation as resolved.
Show resolved Hide resolved

global _workflow_registry
new_workflow = {optimization_level: workflow}
if machine_or_gateset in _workflow_registry:
if optimization_level in _workflow_registry[machine_or_gateset]:
m = f'Overwritting workflow for {machine_or_gateset} '
m += f'at level {optimization_level}.'
_logger.warn(m)
mtweiden marked this conversation as resolved.
Show resolved Hide resolved
_workflow_registry[machine_or_gateset].update(new_workflow)
else:
_workflow_registry[machine_or_gateset] = new_workflow


def clear_registry() -> None:
mtweiden marked this conversation as resolved.
Show resolved Hide resolved
"""
Clear the workflow registry.

This will remove all registered workflows from the registry.
"""
global _workflow_registry
_workflow_registry.clear()
129 changes: 129 additions & 0 deletions tests/compiler/test_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
"""This file tests the register_workflow function."""
from __future__ import annotations

from itertools import combinations
from random import choice

import pytest
from numpy import allclose

from bqskit.compiler import compile
from bqskit.compiler.machine import MachineModel
from bqskit.compiler.registry import _workflow_registry
from bqskit.compiler.registry import clear_registry
from bqskit.compiler.registry import register_workflow
from bqskit.compiler.workflow import Workflow
from bqskit.compiler.workflow import WorkflowLike
from bqskit.ir import Circuit
from bqskit.ir import Gate
from bqskit.ir.gates import CZGate
from bqskit.ir.gates import HGate
from bqskit.ir.gates import RZGate
from bqskit.ir.gates import U3Gate
from bqskit.passes import QSearchSynthesisPass
from bqskit.passes import QuickPartitioner
from bqskit.passes import ScanningGateRemovalPass


def machine_match(mach_a: MachineModel, mach_b: MachineModel) -> bool:
if mach_a.num_qudits != mach_b.num_qudits:
return False
if mach_a.radixes != mach_b.radixes:
return False
if mach_a.coupling_graph != mach_b.coupling_graph:
return False
if mach_a.gate_set != mach_b.gate_set:
return False
return True


def unitary_match(unit_a: Circuit, unit_b: Circuit) -> bool:
return allclose(unit_a.get_unitary(), unit_b.get_unitary(), atol=1e-5)


def workflow_match(
workflow_a: WorkflowLike,
workflow_b: WorkflowLike,
) -> bool:
if not isinstance(workflow_a, Workflow):
workflow_a = Workflow(workflow_a)
if not isinstance(workflow_b, Workflow):
workflow_b = Workflow(workflow_b)
if len(workflow_a) != len(workflow_b):
return False
for a, b in zip(workflow_a, workflow_b):
if a.name != b.name:
return False
return True


def simple_circuit(num_qudits: int, gate_set: list[Gate]) -> Circuit:
circ = Circuit(num_qudits)
gate = choice(gate_set)
if gate.num_qudits == 1:
loc = choice(range(num_qudits))
else:
loc = choice(list(combinations(range(num_qudits), 2))) # type: ignore
gate_inv = gate.get_inverse()
circ.append_gate(gate, loc)
circ.append_gate(gate_inv, loc)
return circ


class TestRegisterWorkflow:

@pytest.fixture(autouse=True)
def setup(self) -> None:
# _workflow_registry.clear()
clear_registry()

def test_register_workflow(self) -> None:
assert _workflow_registry == {}
gateset = [CZGate(), HGate(), RZGate()]
num_qudits = 3
machine = MachineModel(num_qudits, gate_set=gateset)
workflow = [QuickPartitioner(), ScanningGateRemovalPass()]
register_workflow(machine, workflow)
assert machine in _workflow_registry
assert 1 in _workflow_registry[machine]
assert workflow_match(_workflow_registry[machine][1], workflow)

def test_custom_compile_machine(self) -> None:
gateset = [CZGate(), HGate(), RZGate()]
num_qudits = 3
machine = MachineModel(num_qudits, gate_set=gateset)
workflow = [QuickPartitioner(2)]
register_workflow(machine, workflow)
circuit = simple_circuit(num_qudits, gateset)
result = compile(circuit, machine)
assert unitary_match(result, circuit)
assert result.num_operations > 0
assert result.gate_counts != circuit.gate_counts
result.unfold_all()
assert result.gate_counts == circuit.gate_counts

def test_custom_compile_gateset(self) -> None:
gateset = [CZGate(), HGate(), RZGate()]
num_qudits = 3
machine = MachineModel(num_qudits, gate_set=gateset)
workflow = [QuickPartitioner(2)]
register_workflow(gateset, workflow)
circuit = simple_circuit(num_qudits, gateset)
result = compile(circuit, machine)
assert unitary_match(result, circuit)
assert result.num_operations > 0
assert result.gate_counts != circuit.gate_counts
result.unfold_all()
assert result.gate_counts == circuit.gate_counts

def test_custom_opt_level(self) -> None:
gateset = [CZGate(), HGate(), RZGate()]
num_qudits = 3
machine = MachineModel(num_qudits, gate_set=gateset)
workflow = [QSearchSynthesisPass()]
register_workflow(gateset, workflow, 2)
circuit = simple_circuit(num_qudits, gateset)
result = compile(circuit, machine, optimization_level=2)
assert unitary_match(result, circuit)
assert result.gate_counts != circuit.gate_counts
assert U3Gate() in result.gate_set
Loading