Skip to content

Commit

Permalink
Merge pull request #272 from BQSKit/typed-registry
Browse files Browse the repository at this point in the history
Register workflows by target type
  • Loading branch information
edyounis authored Sep 11, 2024
2 parents 264ded8 + 9629e3c commit e4b212e
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 31 deletions.
57 changes: 50 additions & 7 deletions bqskit/compiler/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
from bqskit.compiler.compiler import Compiler
from bqskit.compiler.machine import MachineModel
from bqskit.compiler.passdata import PassData
from bqskit.compiler.registry import _compile_registry
from bqskit.compiler.registry import _compile_circuit_registry
from bqskit.compiler.registry import _compile_statemap_registry
from bqskit.compiler.registry import _compile_stateprep_registry
from bqskit.compiler.registry import _compile_unitary_registry
from bqskit.compiler.registry import model_registered_target_types
from bqskit.compiler.workflow import Workflow
from bqskit.compiler.workflow import WorkflowLike
from bqskit.ir.circuit import Circuit
Expand Down Expand Up @@ -622,8 +626,11 @@ def type_and_check_input(input: CompilationInputLike) -> CompilationInput:
if isinstance(typed_input, Circuit):
in_circuit = typed_input

elif isinstance(typed_input, UnitaryMatrix):
in_circuit = Circuit.from_unitary(typed_input)

else:
in_circuit = Circuit(1)
in_circuit = Circuit(typed_input.num_qudits, typed_input.radixes)

# Perform the compilation
out, data = compiler.compile(in_circuit, workflow, True)
Expand Down Expand Up @@ -669,11 +676,7 @@ def build_workflow(
if model is None:
model = MachineModel(input.num_qudits, radixes=input.radixes)

# Use a registered workflow if model is found in the registry for a given
# optimization_level
if model in _compile_registry:
if optimization_level in _compile_registry[model]:
return _compile_registry[model][optimization_level]
model_registered_types = model_registered_target_types(model)

if isinstance(input, Circuit):
if input.num_qudits > max_synthesis_size:
Expand All @@ -691,6 +694,16 @@ def build_workflow(
'Unable to compile circuit with gate larger than'
' max_synthesis_size.\nConsider adjusting it.',
)
# Use a registered workflow if model is found in the circuit registry
# for a given optimization_level
if model in _compile_circuit_registry:
if optimization_level in _compile_circuit_registry[model]:
return _compile_circuit_registry[model][optimization_level]
elif len(model_registered_types) > 0:
m = f'MachineModel {model} is registered for inputs of type in '
m += f'{model_registered_types}, but input is {type(input)}. '
m += f'You may need to register a Workflow for type {type(input)}.'
warnings.warn(m)

return _circuit_workflow(
model,
Expand All @@ -708,6 +721,16 @@ def build_workflow(
'Unable to compile unitary with size larger than'
' max_synthesis_size.\nConsider adjusting it.',
)
# Use a registered workflow if model is found in the unitary registry
# for a given optimization_level
if model in _compile_unitary_registry:
if optimization_level in _compile_unitary_registry[model]:
return _compile_unitary_registry[model][optimization_level]
elif len(model_registered_types) > 0:
m = f'MachineModel {model} is registered for inputs of type in '
m += f'{model_registered_types}, but input is {type(input)}. '
m += f'You may need to register a Workflow for type {type(input)}.'
warnings.warn(m)

return _synthesis_workflow(
input,
Expand All @@ -726,6 +749,16 @@ def build_workflow(
'Unable to compile states with size larger than'
' max_synthesis_size.\nConsider adjusting it.',
)
# Use a registered workflow if model is found in the stateprep registry
# for a given optimization_level
if model in _compile_stateprep_registry:
if optimization_level in _compile_stateprep_registry[model]:
return _compile_stateprep_registry[model][optimization_level]
elif len(model_registered_types) > 0:
m = f'MachineModel {model} is registered for inputs of type in '
m += f'{model_registered_types}, but input is {type(input)}. '
m += f'You may need to register a Workflow for type {type(input)}.'
warnings.warn(m)

return _stateprep_workflow(
input,
Expand All @@ -744,6 +777,16 @@ def build_workflow(
'Unable to compile state systems with size larger than'
' max_synthesis_size.\nConsider adjusting it.',
)
# Use a registered workflow if model is found in the statemap registry
# for a given optimization_level
if model in _compile_statemap_registry:
if optimization_level in _compile_statemap_registry[model]:
return _compile_statemap_registry[model][optimization_level]
elif len(model_registered_types) > 0:
m = f'MachineModel {model} is registered for inputs of type in '
m += f'{model_registered_types}, but input is {type(input)}. '
m += f'You may need to register a Workflow for type {type(input)}.'
warnings.warn(m)

return _statemap_workflow(
input,
Expand Down
73 changes: 64 additions & 9 deletions bqskit/compiler/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,44 @@
from bqskit.compiler.workflow import WorkflowLike


_compile_registry: dict[MachineModel, dict[int, Workflow]] = {}
_compile_circuit_registry: dict[MachineModel, dict[int, Workflow]] = {}
_compile_unitary_registry: dict[MachineModel, dict[int, Workflow]] = {}
_compile_stateprep_registry: dict[MachineModel, dict[int, Workflow]] = {}
_compile_statemap_registry: dict[MachineModel, dict[int, Workflow]] = {}


def model_registered_target_types(key: MachineModel) -> list[str]:
"""
Return a list of target_types for which key is registered.
Args:
key (MachineModel): A MachineModel to check for.
Returns:
(list[str]): If `key` has been registered in any of the registry, the
name of that target type will be contained in this list.
"""
global _compile_circuit_registry
global _compile_unitary_registry
global _compile_stateprep_registry
global _compile_statemap_registry
registered_types = []
if key in _compile_circuit_registry:
registered_types.append('circuit')
if key in _compile_unitary_registry:
registered_types.append('unitary')
if key in _compile_stateprep_registry:
registered_types.append('stateprep')
if key in _compile_statemap_registry:
registered_types.append('statemap')
return registered_types


def register_workflow(
key: MachineModel,
workflow: WorkflowLike,
optimization_level: int,
target_type: str = 'circuit',
) -> None:
"""
Register a workflow for a given MachineModel.
Expand All @@ -34,30 +65,54 @@ def register_workflow(
be executed if the MachineModel in a call to `compile` matches
`key`. If `key` is already registered, a warning will be logged.
optimization_level ptional[int): The optimization level with which
to register the workflow. If no level is provided, the Workflow
will be registered as level 1.
optimization_level (int): The optimization level with which to
register the workflow.
target_type (str): Register a workflow for targets of this type. Must
be 'circuit', 'unitary', 'stateprep', or 'statemap'.
(Default: 'circuit')
Example:
model_t = SpecificMachineModel(num_qudits, radixes)
workflow = [QuickPartitioner(3), NewFangledOptimization()]
register_workflow(model_t, workflow, level)
register_workflow(model_t, workflow, level, 'circuit')
...
new_circuit = compile(circuit, model_t, optimization_level=level)
Raises:
Warning: If a workflow for a given optimization_level is overwritten.
ValueError: If `target_type` is not 'circuit', 'unitary', 'stateprep',
or 'statemap'.
"""
if target_type not in ['circuit', 'unitary', 'stateprep', 'statemap']:
m = 'target_type must be "circuit", "unitary", "stateprep", or '
m += f'"statemap", got {target_type}.'
raise ValueError(m)

if target_type == 'circuit':
global _compile_circuit_registry
_compile_registry = _compile_circuit_registry
elif target_type == 'unitary':
global _compile_unitary_registry
_compile_registry = _compile_unitary_registry
elif target_type == 'stateprep':
global _compile_stateprep_registry
_compile_registry = _compile_stateprep_registry
else:
global _compile_statemap_registry
_compile_registry = _compile_statemap_registry

workflow = Workflow(workflow)

global _compile_registry
new_workflow = {optimization_level: workflow}
if key in _compile_registry:
if optimization_level in _compile_registry[key]:
m = f'Overwritting workflow for {key} at level '
m += f'{optimization_level}. If multiple Namespace packages are '
m += 'installed, ensure that their __init__.py files do not '
m += 'attempt to overwrite the same default Workflows.'
m += f'{optimization_level} for target type {target_type}.'
m += 'If multiple Namespace packages are installed, ensure'
m += 'that their __init__.py files do not attempt to'
m += 'overwrite the same default Workflows.'
warnings.warn(m)
_compile_registry[key].update(new_workflow)
else:
Expand Down
63 changes: 48 additions & 15 deletions tests/compiler/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@

from bqskit.compiler.compile import compile
from bqskit.compiler.machine import MachineModel
from bqskit.compiler.registry import _compile_registry
from bqskit.compiler.registry import _compile_circuit_registry
from bqskit.compiler.registry import _compile_statemap_registry
from bqskit.compiler.registry import _compile_stateprep_registry
from bqskit.compiler.registry import _compile_unitary_registry
from bqskit.compiler.registry import register_workflow
from bqskit.compiler.workflow import Workflow
from bqskit.compiler.workflow import WorkflowLike
Expand Down Expand Up @@ -74,28 +77,58 @@ class TestRegisterWorkflow:
@pytest.fixture(autouse=True)
def setup(self) -> None:
# global _compile_registry
_compile_registry.clear()
_compile_circuit_registry.clear()
_compile_unitary_registry.clear()
_compile_statemap_registry.clear()
_compile_stateprep_registry.clear()

def test_register_workflow(self) -> None:
global _compile_registry
assert _compile_registry == {}
global _compile_circuit_registry
global _compile_unitary_registry
global _compile_statemap_registry
global _compile_stateprep_registry
assert _compile_circuit_registry == {}
assert _compile_unitary_registry == {}
assert _compile_statemap_registry == {}
assert _compile_stateprep_registry == {}
gateset = [CZGate(), HGate(), RZGate()]
num_qudits = 3
machine = MachineModel(num_qudits, gate_set=gateset)
workflow = [QuickPartitioner(), ScanningGateRemovalPass()]
register_workflow(machine, workflow, 1)
assert machine in _compile_registry
assert 1 in _compile_registry[machine]
assert workflow_match(_compile_registry[machine][1], workflow)
circuit_workflow = [QuickPartitioner(), ScanningGateRemovalPass()]
other_workflow = [QuickPartitioner(), QSearchSynthesisPass()]
register_workflow(machine, circuit_workflow, 1, 'circuit')
register_workflow(machine, other_workflow, 1, 'unitary')
register_workflow(machine, other_workflow, 1, 'statemap')
register_workflow(machine, other_workflow, 1, 'stateprep')
assert machine in _compile_circuit_registry
assert 1 in _compile_circuit_registry[machine]
assert workflow_match(
_compile_circuit_registry[machine][1], circuit_workflow,
)
assert machine in _compile_unitary_registry
assert 1 in _compile_unitary_registry[machine]
assert workflow_match(
_compile_unitary_registry[machine][1], other_workflow,
)
assert machine in _compile_statemap_registry
assert 1 in _compile_statemap_registry[machine]
assert workflow_match(
_compile_statemap_registry[machine][1], other_workflow,
)
assert machine in _compile_stateprep_registry
assert 1 in _compile_stateprep_registry[machine]
assert workflow_match(
_compile_stateprep_registry[machine][1], other_workflow,
)

def test_custom_compile_machine(self) -> None:
global _compile_registry
assert _compile_registry == {}
global _compile_circuit_registry
assert _compile_circuit_registry == {}
gateset = [CZGate(), HGate(), RZGate()]
num_qudits = 3
machine = MachineModel(num_qudits, gate_set=gateset)
workflow = [QuickPartitioner(2)]
register_workflow(machine, workflow, 1)
register_workflow(machine, workflow, 1, 'circuit')
circuit = simple_circuit(num_qudits, gateset)
result = compile(circuit, machine)
assert unitary_match(result, circuit)
Expand All @@ -105,13 +138,13 @@ def test_custom_compile_machine(self) -> None:
assert result.gate_counts == circuit.gate_counts

def test_custom_opt_level(self) -> None:
global _compile_registry
assert _compile_registry == {}
global _compile_circuit_registry
assert _compile_circuit_registry == {}
gateset = [CZGate(), HGate(), RZGate()]
num_qudits = 3
machine = MachineModel(num_qudits, gate_set=gateset)
workflow = [QSearchSynthesisPass()]
register_workflow(machine, workflow, 2)
register_workflow(machine, workflow, 2, 'circuit')
circuit = simple_circuit(num_qudits, gateset)
result = compile(circuit, machine, optimization_level=2)
assert unitary_match(result, circuit)
Expand Down

0 comments on commit e4b212e

Please sign in to comment.