diff --git a/bqskit/compiler/compile.py b/bqskit/compiler/compile.py index 4824c42ad..cac057bd3 100644 --- a/bqskit/compiler/compile.py +++ b/bqskit/compiler/compile.py @@ -15,6 +15,7 @@ from bqskit.compiler.compiler import Compiler from bqskit.compiler.machine import MachineModel from bqskit.compiler.passdata import PassData +from bqskit.compiler.registry import _compile_registry from bqskit.compiler.workflow import Workflow from bqskit.compiler.workflow import WorkflowLike from bqskit.ir.circuit import Circuit @@ -669,6 +670,12 @@ def build_workflow( if model is None: model = MachineModel(input.num_qudits, radixes=input.radixes) + # Use a registered workflow if model is found in the registry for a given + # optimization_level + if model in _compile_registry: + if optimization_level in _compile_registry[model]: + return _compile_registry[model][optimization_level] + if isinstance(input, Circuit): if input.num_qudits > max_synthesis_size: if any( diff --git a/bqskit/compiler/gateset.py b/bqskit/compiler/gateset.py index d8111ec4a..be50ce692 100644 --- a/bqskit/compiler/gateset.py +++ b/bqskit/compiler/gateset.py @@ -231,5 +231,9 @@ def __repr__(self) -> str: """Detailed representation of the GateSet.""" return self._gates.__repr__().replace('frozenset', 'GateSet') + def __hash__(self) -> int: + """Hash of the GateSet.""" + return hash(tuple(sorted([g.name for g in self._gates]))) + GateSetLike = Union[GateSet, Iterable[Gate], Gate] diff --git a/bqskit/compiler/registry.py b/bqskit/compiler/registry.py new file mode 100644 index 000000000..b3cb1b4d2 --- /dev/null +++ b/bqskit/compiler/registry.py @@ -0,0 +1,64 @@ +"""Register MachineModel specific default workflows.""" +from __future__ import annotations + +import warnings + +from bqskit.compiler.machine import MachineModel +from bqskit.compiler.workflow import Workflow +from bqskit.compiler.workflow import WorkflowLike + + +_compile_registry: dict[MachineModel, dict[int, Workflow]] = {} + + +def register_workflow( + key: MachineModel, + workflow: WorkflowLike, + optimization_level: int, +) -> None: + """ + Register a workflow for a given MachineModel. + + The _compile_registry enables MachineModel specific workflows to be + registered for use in the `bqskit.compile` method. _compile_registry maps + MachineModels a dictionary of Workflows which are indexed by optimization + level. This object should not be accessed directly by the user, but + instead through the `register_workflow` function. + + Args: + key (MachineModel): A MachineModel to register the workflow under. + If a circuit is compiled targeting this machine or gate set, the + registered workflow will be used. + + workflow (list[BasePass]): The workflow or list of passes that will + be executed if the MachineModel in a call to `compile` matches + `key`. If `key` is already registered, a warning will be logged. + + optimization_level ptional[int): The optimization level with which + to register the workflow. If no level is provided, the Workflow + will be registered as level 1. + + Example: + model_t = SpecificMachineModel(num_qudits, radixes) + workflow = [QuickPartitioner(3), NewFangledOptimization()] + register_workflow(model_t, workflow, level) + ... + new_circuit = compile(circuit, model_t, optimization_level=level) + + Raises: + Warning: If a workflow for a given optimization_level is overwritten. + """ + workflow = Workflow(workflow) + + global _compile_registry + new_workflow = {optimization_level: workflow} + if key in _compile_registry: + if optimization_level in _compile_registry[key]: + m = f'Overwritting workflow for {key} at level ' + m += f'{optimization_level}. If multiple Namespace packages are ' + m += 'installed, ensure that their __init__.py files do not ' + m += 'attempt to overwrite the same default Workflows.' + warnings.warn(m) + _compile_registry[key].update(new_workflow) + else: + _compile_registry[key] = new_workflow diff --git a/bqskit/compiler/workflow.py b/bqskit/compiler/workflow.py index 6134d07aa..0771ce6d5 100644 --- a/bqskit/compiler/workflow.py +++ b/bqskit/compiler/workflow.py @@ -87,6 +87,12 @@ def name(self) -> str: """The name of the pass.""" return self._name or self.__class__.__name__ + @staticmethod + def is_workflow(workflow: WorkflowLike) -> bool: + if not is_iterable(workflow): + return isinstance(workflow, BasePass) + return all(isinstance(p, BasePass) for p in workflow) + def __str__(self) -> str: name_seq = f'Workflow: {self.name}\n\t' pass_strs = [ diff --git a/tests/compiler/test_gateset.py b/tests/compiler/test_gateset.py index 009003a95..4f89b8fca 100644 --- a/tests/compiler/test_gateset.py +++ b/tests/compiler/test_gateset.py @@ -522,3 +522,16 @@ def test_gate_set_repr() -> None: repr(gate_set) == 'GateSet({CNOTGate, U3Gate})' or repr(gate_set) == 'GateSet({U3Gate, CNOTGate})' ) + + +def test_gate_set_hash() -> None: + gate_set_1 = GateSet({CNOTGate(), U3Gate()}) + gate_set_2 = GateSet({U3Gate(), CNOTGate()}) + gate_set_3 = GateSet({U3Gate(), CNOTGate(), RZGate()}) + + h1 = hash(gate_set_1) + h2 = hash(gate_set_2) + h3 = hash(gate_set_3) + + assert h1 == h2 + assert h1 != h3 diff --git a/tests/compiler/test_registry.py b/tests/compiler/test_registry.py new file mode 100644 index 000000000..6371211c9 --- /dev/null +++ b/tests/compiler/test_registry.py @@ -0,0 +1,119 @@ +"""This file tests the register_workflow function.""" +from __future__ import annotations + +from itertools import combinations +from random import choice + +import pytest +from numpy import allclose + +from bqskit.compiler.compile import compile +from bqskit.compiler.machine import MachineModel +from bqskit.compiler.registry import _compile_registry +from bqskit.compiler.registry import register_workflow +from bqskit.compiler.workflow import Workflow +from bqskit.compiler.workflow import WorkflowLike +from bqskit.ir import Circuit +from bqskit.ir import Gate +from bqskit.ir.gates import CZGate +from bqskit.ir.gates import HGate +from bqskit.ir.gates import RZGate +from bqskit.ir.gates import U3Gate +from bqskit.passes import QSearchSynthesisPass +from bqskit.passes import QuickPartitioner +from bqskit.passes import ScanningGateRemovalPass + + +def machine_match(mach_a: MachineModel, mach_b: MachineModel) -> bool: + if mach_a.num_qudits != mach_b.num_qudits: + return False + if mach_a.radixes != mach_b.radixes: + return False + if mach_a.coupling_graph != mach_b.coupling_graph: + return False + if mach_a.gate_set != mach_b.gate_set: + return False + return True + + +def unitary_match(unit_a: Circuit, unit_b: Circuit) -> bool: + return allclose(unit_a.get_unitary(), unit_b.get_unitary(), atol=1e-5) + + +def workflow_match( + workflow_a: WorkflowLike, + workflow_b: WorkflowLike, +) -> bool: + if not isinstance(workflow_a, Workflow): + workflow_a = Workflow(workflow_a) + if not isinstance(workflow_b, Workflow): + workflow_b = Workflow(workflow_b) + if len(workflow_a) != len(workflow_b): + return False + for a, b in zip(workflow_a, workflow_b): + if a.name != b.name: + return False + return True + + +def simple_circuit(num_qudits: int, gate_set: list[Gate]) -> Circuit: + circ = Circuit(num_qudits) + gate = choice(gate_set) + if gate.num_qudits == 1: + loc = choice(range(num_qudits)) + else: + loc = choice(list(combinations(range(num_qudits), 2))) # type: ignore + gate_inv = gate.get_inverse() + circ.append_gate(gate, loc) + circ.append_gate(gate_inv, loc) + return circ + + +class TestRegisterWorkflow: + + @pytest.fixture(autouse=True) + def setup(self) -> None: + # global _compile_registry + _compile_registry.clear() + + def test_register_workflow(self) -> None: + global _compile_registry + assert _compile_registry == {} + gateset = [CZGate(), HGate(), RZGate()] + num_qudits = 3 + machine = MachineModel(num_qudits, gate_set=gateset) + workflow = [QuickPartitioner(), ScanningGateRemovalPass()] + register_workflow(machine, workflow, 1) + assert machine in _compile_registry + assert 1 in _compile_registry[machine] + assert workflow_match(_compile_registry[machine][1], workflow) + + def test_custom_compile_machine(self) -> None: + global _compile_registry + assert _compile_registry == {} + gateset = [CZGate(), HGate(), RZGate()] + num_qudits = 3 + machine = MachineModel(num_qudits, gate_set=gateset) + workflow = [QuickPartitioner(2)] + register_workflow(machine, workflow, 1) + circuit = simple_circuit(num_qudits, gateset) + result = compile(circuit, machine) + assert unitary_match(result, circuit) + assert result.num_operations > 0 + assert result.gate_counts != circuit.gate_counts + result.unfold_all() + assert result.gate_counts == circuit.gate_counts + + def test_custom_opt_level(self) -> None: + global _compile_registry + assert _compile_registry == {} + gateset = [CZGate(), HGate(), RZGate()] + num_qudits = 3 + machine = MachineModel(num_qudits, gate_set=gateset) + workflow = [QSearchSynthesisPass()] + register_workflow(machine, workflow, 2) + circuit = simple_circuit(num_qudits, gateset) + result = compile(circuit, machine, optimization_level=2) + assert unitary_match(result, circuit) + assert result.gate_counts != circuit.gate_counts + assert U3Gate() in result.gate_set