Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Register workflows by target type #272

Merged
merged 6 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 28 additions & 8 deletions bqskit/compiler/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
from bqskit.compiler.compiler import Compiler
from bqskit.compiler.machine import MachineModel
from bqskit.compiler.passdata import PassData
from bqskit.compiler.registry import _compile_registry
from bqskit.compiler.registry import _compile_circuit_registry
from bqskit.compiler.registry import _compile_statemap_registry
from bqskit.compiler.registry import _compile_stateprep_registry
from bqskit.compiler.registry import _compile_unitary_registry
from bqskit.compiler.workflow import Workflow
from bqskit.compiler.workflow import WorkflowLike
from bqskit.ir.circuit import Circuit
Expand Down Expand Up @@ -622,8 +625,11 @@ def type_and_check_input(input: CompilationInputLike) -> CompilationInput:
if isinstance(typed_input, Circuit):
in_circuit = typed_input

elif isinstance(typed_input, UnitaryMatrix):
in_circuit = Circuit.from_unitary(typed_input)

else:
in_circuit = Circuit(1)
in_circuit = Circuit(typed_input.num_qudits, typed_input.radixes)

# Perform the compilation
out, data = compiler.compile(in_circuit, workflow, True)
Expand Down Expand Up @@ -669,12 +675,6 @@ def build_workflow(
if model is None:
model = MachineModel(input.num_qudits, radixes=input.radixes)

# Use a registered workflow if model is found in the registry for a given
# optimization_level
if model in _compile_registry:
if optimization_level in _compile_registry[model]:
return _compile_registry[model][optimization_level]

if isinstance(input, Circuit):
if input.num_qudits > max_synthesis_size:
if any(
Expand All @@ -691,6 +691,11 @@ def build_workflow(
'Unable to compile circuit with gate larger than'
' max_synthesis_size.\nConsider adjusting it.',
)
# Use a registered workflow if model is found in the circuit registry
# for a given optimization_level
if model in _compile_circuit_registry:
if optimization_level in _compile_circuit_registry[model]:
return _compile_circuit_registry[model][optimization_level]

return _circuit_workflow(
model,
Expand All @@ -708,6 +713,11 @@ def build_workflow(
'Unable to compile unitary with size larger than'
' max_synthesis_size.\nConsider adjusting it.',
)
# Use a registered workflow if model is found in the unitary registry
# for a given optimization_level
if model in _compile_unitary_registry:
if optimization_level in _compile_unitary_registry[model]:
return _compile_unitary_registry[model][optimization_level]

return _synthesis_workflow(
input,
Expand All @@ -726,6 +736,11 @@ def build_workflow(
'Unable to compile states with size larger than'
' max_synthesis_size.\nConsider adjusting it.',
)
# Use a registered workflow if model is found in the stateprep registry
# for a given optimization_level
if model in _compile_stateprep_registry:
if optimization_level in _compile_stateprep_registry[model]:
return _compile_stateprep_registry[model][optimization_level]

return _stateprep_workflow(
input,
Expand All @@ -744,6 +759,11 @@ def build_workflow(
'Unable to compile state systems with size larger than'
' max_synthesis_size.\nConsider adjusting it.',
)
# Use a registered workflow if model is found in the statemap registry
# for a given optimization_level
if model in _compile_statemap_registry:
if optimization_level in _compile_statemap_registry[model]:
return _compile_statemap_registry[model][optimization_level]

return _statemap_workflow(
input,
Expand Down
40 changes: 34 additions & 6 deletions bqskit/compiler/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,17 @@
from bqskit.compiler.workflow import WorkflowLike


_compile_registry: dict[MachineModel, dict[int, Workflow]] = {}
_compile_circuit_registry: dict[MachineModel, dict[int, Workflow]] = {}
_compile_unitary_registry: dict[MachineModel, dict[int, Workflow]] = {}
_compile_stateprep_registry: dict[MachineModel, dict[int, Workflow]] = {}
_compile_statemap_registry: dict[MachineModel, dict[int, Workflow]] = {}


def register_workflow(
key: MachineModel,
workflow: WorkflowLike,
optimization_level: int,
target_type: str,
) -> None:
"""
Register a workflow for a given MachineModel.
Expand All @@ -34,10 +38,13 @@ def register_workflow(
be executed if the MachineModel in a call to `compile` matches
`key`. If `key` is already registered, a warning will be logged.

optimization_level ptional[int): The optimization level with which
optimization_level (Optional[int]): The optimization level with which
to register the workflow. If no level is provided, the Workflow
will be registered as level 1.
mtweiden marked this conversation as resolved.
Show resolved Hide resolved

target_type (str): Register a workflow for targets of this type. Must
be 'circuit', 'unitary', 'stateprep', or 'statemap'.
edyounis marked this conversation as resolved.
Show resolved Hide resolved

Example:
model_t = SpecificMachineModel(num_qudits, radixes)
workflow = [QuickPartitioner(3), NewFangledOptimization()]
Expand All @@ -47,17 +54,38 @@ def register_workflow(

Raises:
Warning: If a workflow for a given optimization_level is overwritten.

ValueError: If `target_type` is not 'circuit', 'unitary', 'stateprep',
or 'statemap'.
"""
if target_type not in ['circuit', 'unitary', 'stateprep', 'statemap']:
m = 'target_type must be "circuit", "unitary", "stateprep", or '
m += f'"statemap", got {target_type}.'
raise ValueError(m)

if target_type == 'circuit':
global _compile_circuit_registry
_compile_registry = _compile_circuit_registry
elif target_type == 'unitary':
global _compile_unitary_registry
_compile_registry = _compile_unitary_registry
elif target_type == 'stateprep':
global _compile_stateprep_registry
_compile_registry = _compile_stateprep_registry
else:
global _compile_statemap_registry
_compile_registry = _compile_statemap_registry

workflow = Workflow(workflow)

global _compile_registry
new_workflow = {optimization_level: workflow}
if key in _compile_registry:
if optimization_level in _compile_registry[key]:
m = f'Overwritting workflow for {key} at level '
m += f'{optimization_level}. If multiple Namespace packages are '
m += 'installed, ensure that their __init__.py files do not '
m += 'attempt to overwrite the same default Workflows.'
m += f'{optimization_level} for target type {target_type}.'
m += 'If multiple Namespace packages are installed, ensure'
m += 'that their __init__.py files do not attempt to'
m += 'overwrite the same default Workflows.'
warnings.warn(m)
_compile_registry[key].update(new_workflow)
else:
Expand Down
63 changes: 48 additions & 15 deletions tests/compiler/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@

from bqskit.compiler.compile import compile
from bqskit.compiler.machine import MachineModel
from bqskit.compiler.registry import _compile_registry
from bqskit.compiler.registry import _compile_circuit_registry
from bqskit.compiler.registry import _compile_statemap_registry
from bqskit.compiler.registry import _compile_stateprep_registry
from bqskit.compiler.registry import _compile_unitary_registry
from bqskit.compiler.registry import register_workflow
from bqskit.compiler.workflow import Workflow
from bqskit.compiler.workflow import WorkflowLike
Expand Down Expand Up @@ -74,28 +77,58 @@ class TestRegisterWorkflow:
@pytest.fixture(autouse=True)
def setup(self) -> None:
# global _compile_registry
_compile_registry.clear()
_compile_circuit_registry.clear()
_compile_unitary_registry.clear()
_compile_statemap_registry.clear()
_compile_stateprep_registry.clear()

def test_register_workflow(self) -> None:
global _compile_registry
assert _compile_registry == {}
global _compile_circuit_registry
global _compile_unitary_registry
global _compile_statemap_registry
global _compile_stateprep_registry
assert _compile_circuit_registry == {}
assert _compile_unitary_registry == {}
assert _compile_statemap_registry == {}
assert _compile_stateprep_registry == {}
gateset = [CZGate(), HGate(), RZGate()]
num_qudits = 3
machine = MachineModel(num_qudits, gate_set=gateset)
workflow = [QuickPartitioner(), ScanningGateRemovalPass()]
register_workflow(machine, workflow, 1)
assert machine in _compile_registry
assert 1 in _compile_registry[machine]
assert workflow_match(_compile_registry[machine][1], workflow)
circuit_workflow = [QuickPartitioner(), ScanningGateRemovalPass()]
other_workflow = [QuickPartitioner(), QSearchSynthesisPass()]
register_workflow(machine, circuit_workflow, 1, 'circuit')
register_workflow(machine, other_workflow, 1, 'unitary')
register_workflow(machine, other_workflow, 1, 'statemap')
register_workflow(machine, other_workflow, 1, 'stateprep')
assert machine in _compile_circuit_registry
assert 1 in _compile_circuit_registry[machine]
assert workflow_match(
_compile_circuit_registry[machine][1], circuit_workflow,
)
assert machine in _compile_unitary_registry
assert 1 in _compile_unitary_registry[machine]
assert workflow_match(
_compile_unitary_registry[machine][1], other_workflow,
)
assert machine in _compile_statemap_registry
assert 1 in _compile_statemap_registry[machine]
assert workflow_match(
_compile_statemap_registry[machine][1], other_workflow,
)
assert machine in _compile_stateprep_registry
assert 1 in _compile_stateprep_registry[machine]
assert workflow_match(
_compile_stateprep_registry[machine][1], other_workflow,
)

def test_custom_compile_machine(self) -> None:
global _compile_registry
assert _compile_registry == {}
global _compile_circuit_registry
assert _compile_circuit_registry == {}
gateset = [CZGate(), HGate(), RZGate()]
num_qudits = 3
machine = MachineModel(num_qudits, gate_set=gateset)
workflow = [QuickPartitioner(2)]
register_workflow(machine, workflow, 1)
register_workflow(machine, workflow, 1, 'circuit')
circuit = simple_circuit(num_qudits, gateset)
result = compile(circuit, machine)
assert unitary_match(result, circuit)
Expand All @@ -105,13 +138,13 @@ def test_custom_compile_machine(self) -> None:
assert result.gate_counts == circuit.gate_counts

def test_custom_opt_level(self) -> None:
global _compile_registry
assert _compile_registry == {}
global _compile_circuit_registry
assert _compile_circuit_registry == {}
gateset = [CZGate(), HGate(), RZGate()]
num_qudits = 3
machine = MachineModel(num_qudits, gate_set=gateset)
workflow = [QSearchSynthesisPass()]
register_workflow(machine, workflow, 2)
register_workflow(machine, workflow, 2, 'circuit')
circuit = simple_circuit(num_qudits, gateset)
result = compile(circuit, machine, optimization_level=2)
assert unitary_match(result, circuit)
Expand Down
Loading