Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Register workflows by target type #272

Merged
merged 6 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 50 additions & 7 deletions bqskit/compiler/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
from bqskit.compiler.compiler import Compiler
from bqskit.compiler.machine import MachineModel
from bqskit.compiler.passdata import PassData
from bqskit.compiler.registry import _compile_registry
from bqskit.compiler.registry import _compile_circuit_registry
from bqskit.compiler.registry import _compile_statemap_registry
from bqskit.compiler.registry import _compile_stateprep_registry
from bqskit.compiler.registry import _compile_unitary_registry
from bqskit.compiler.registry import model_registered_target_types
from bqskit.compiler.workflow import Workflow
from bqskit.compiler.workflow import WorkflowLike
from bqskit.ir.circuit import Circuit
Expand Down Expand Up @@ -622,8 +626,11 @@ def type_and_check_input(input: CompilationInputLike) -> CompilationInput:
if isinstance(typed_input, Circuit):
in_circuit = typed_input

elif isinstance(typed_input, UnitaryMatrix):
in_circuit = Circuit.from_unitary(typed_input)

else:
in_circuit = Circuit(1)
in_circuit = Circuit(typed_input.num_qudits, typed_input.radixes)

# Perform the compilation
out, data = compiler.compile(in_circuit, workflow, True)
Expand Down Expand Up @@ -669,11 +676,7 @@ def build_workflow(
if model is None:
model = MachineModel(input.num_qudits, radixes=input.radixes)

# Use a registered workflow if model is found in the registry for a given
# optimization_level
if model in _compile_registry:
if optimization_level in _compile_registry[model]:
return _compile_registry[model][optimization_level]
model_registered_types = model_registered_target_types(model)

if isinstance(input, Circuit):
if input.num_qudits > max_synthesis_size:
Expand All @@ -691,6 +694,16 @@ def build_workflow(
'Unable to compile circuit with gate larger than'
' max_synthesis_size.\nConsider adjusting it.',
)
# Use a registered workflow if model is found in the circuit registry
# for a given optimization_level
if model in _compile_circuit_registry:
if optimization_level in _compile_circuit_registry[model]:
return _compile_circuit_registry[model][optimization_level]
elif len(model_registered_types) > 0:
m = f'MachineModel {model} is registered for inputs of type in '
m += f'{model_registered_types}, but input is {type(input)}. '
m += f'You may need to register a Workflow for type {type(input)}.'
warnings.warn(m)

return _circuit_workflow(
model,
Expand All @@ -708,6 +721,16 @@ def build_workflow(
'Unable to compile unitary with size larger than'
' max_synthesis_size.\nConsider adjusting it.',
)
# Use a registered workflow if model is found in the unitary registry
# for a given optimization_level
if model in _compile_unitary_registry:
if optimization_level in _compile_unitary_registry[model]:
return _compile_unitary_registry[model][optimization_level]
elif len(model_registered_types) > 0:
m = f'MachineModel {model} is registered for inputs of type in '
m += f'{model_registered_types}, but input is {type(input)}. '
m += f'You may need to register a Workflow for type {type(input)}.'
warnings.warn(m)

return _synthesis_workflow(
input,
Expand All @@ -726,6 +749,16 @@ def build_workflow(
'Unable to compile states with size larger than'
' max_synthesis_size.\nConsider adjusting it.',
)
# Use a registered workflow if model is found in the stateprep registry
# for a given optimization_level
if model in _compile_stateprep_registry:
if optimization_level in _compile_stateprep_registry[model]:
return _compile_stateprep_registry[model][optimization_level]
elif len(model_registered_types) > 0:
m = f'MachineModel {model} is registered for inputs of type in '
m += f'{model_registered_types}, but input is {type(input)}. '
m += f'You may need to register a Workflow for type {type(input)}.'
warnings.warn(m)

return _stateprep_workflow(
input,
Expand All @@ -744,6 +777,16 @@ def build_workflow(
'Unable to compile state systems with size larger than'
' max_synthesis_size.\nConsider adjusting it.',
)
# Use a registered workflow if model is found in the statemap registry
# for a given optimization_level
if model in _compile_statemap_registry:
if optimization_level in _compile_statemap_registry[model]:
return _compile_statemap_registry[model][optimization_level]
elif len(model_registered_types) > 0:
m = f'MachineModel {model} is registered for inputs of type in '
m += f'{model_registered_types}, but input is {type(input)}. '
m += f'You may need to register a Workflow for type {type(input)}.'
warnings.warn(m)

return _statemap_workflow(
input,
Expand Down
73 changes: 64 additions & 9 deletions bqskit/compiler/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,44 @@
from bqskit.compiler.workflow import WorkflowLike


_compile_registry: dict[MachineModel, dict[int, Workflow]] = {}
_compile_circuit_registry: dict[MachineModel, dict[int, Workflow]] = {}
_compile_unitary_registry: dict[MachineModel, dict[int, Workflow]] = {}
_compile_stateprep_registry: dict[MachineModel, dict[int, Workflow]] = {}
_compile_statemap_registry: dict[MachineModel, dict[int, Workflow]] = {}


def model_registered_target_types(key: MachineModel) -> list[str]:
"""
Return a list of target_types for which key is registered.

Args:
key (MachineModel): A MachineModel to check for.

Returns:
(list[str]): If `key` has been registered in any of the registry, the
name of that target type will be contained in this list.
"""
global _compile_circuit_registry
global _compile_unitary_registry
global _compile_stateprep_registry
global _compile_statemap_registry
registered_types = []
if key in _compile_circuit_registry:
registered_types.append('circuit')
if key in _compile_unitary_registry:
registered_types.append('unitary')
if key in _compile_stateprep_registry:
registered_types.append('stateprep')
if key in _compile_statemap_registry:
registered_types.append('statemap')
return registered_types


def register_workflow(
key: MachineModel,
workflow: WorkflowLike,
optimization_level: int,
target_type: str = 'circuit',
) -> None:
"""
Register a workflow for a given MachineModel.
Expand All @@ -34,30 +65,54 @@ def register_workflow(
be executed if the MachineModel in a call to `compile` matches
`key`. If `key` is already registered, a warning will be logged.

optimization_level ptional[int): The optimization level with which
to register the workflow. If no level is provided, the Workflow
will be registered as level 1.
optimization_level (int): The optimization level with which to
register the workflow.

target_type (str): Register a workflow for targets of this type. Must
be 'circuit', 'unitary', 'stateprep', or 'statemap'.
edyounis marked this conversation as resolved.
Show resolved Hide resolved
(Default: 'circuit')

Example:
model_t = SpecificMachineModel(num_qudits, radixes)
workflow = [QuickPartitioner(3), NewFangledOptimization()]
register_workflow(model_t, workflow, level)
register_workflow(model_t, workflow, level, 'circuit')
...
new_circuit = compile(circuit, model_t, optimization_level=level)

Raises:
Warning: If a workflow for a given optimization_level is overwritten.

ValueError: If `target_type` is not 'circuit', 'unitary', 'stateprep',
or 'statemap'.
"""
if target_type not in ['circuit', 'unitary', 'stateprep', 'statemap']:
m = 'target_type must be "circuit", "unitary", "stateprep", or '
m += f'"statemap", got {target_type}.'
raise ValueError(m)

if target_type == 'circuit':
global _compile_circuit_registry
_compile_registry = _compile_circuit_registry
elif target_type == 'unitary':
global _compile_unitary_registry
_compile_registry = _compile_unitary_registry
elif target_type == 'stateprep':
global _compile_stateprep_registry
_compile_registry = _compile_stateprep_registry
else:
global _compile_statemap_registry
_compile_registry = _compile_statemap_registry

workflow = Workflow(workflow)

global _compile_registry
new_workflow = {optimization_level: workflow}
if key in _compile_registry:
if optimization_level in _compile_registry[key]:
m = f'Overwritting workflow for {key} at level '
m += f'{optimization_level}. If multiple Namespace packages are '
m += 'installed, ensure that their __init__.py files do not '
m += 'attempt to overwrite the same default Workflows.'
m += f'{optimization_level} for target type {target_type}.'
m += 'If multiple Namespace packages are installed, ensure'
m += 'that their __init__.py files do not attempt to'
m += 'overwrite the same default Workflows.'
warnings.warn(m)
_compile_registry[key].update(new_workflow)
else:
Expand Down
63 changes: 48 additions & 15 deletions tests/compiler/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@

from bqskit.compiler.compile import compile
from bqskit.compiler.machine import MachineModel
from bqskit.compiler.registry import _compile_registry
from bqskit.compiler.registry import _compile_circuit_registry
from bqskit.compiler.registry import _compile_statemap_registry
from bqskit.compiler.registry import _compile_stateprep_registry
from bqskit.compiler.registry import _compile_unitary_registry
from bqskit.compiler.registry import register_workflow
from bqskit.compiler.workflow import Workflow
from bqskit.compiler.workflow import WorkflowLike
Expand Down Expand Up @@ -74,28 +77,58 @@ class TestRegisterWorkflow:
@pytest.fixture(autouse=True)
def setup(self) -> None:
# global _compile_registry
_compile_registry.clear()
_compile_circuit_registry.clear()
_compile_unitary_registry.clear()
_compile_statemap_registry.clear()
_compile_stateprep_registry.clear()

def test_register_workflow(self) -> None:
global _compile_registry
assert _compile_registry == {}
global _compile_circuit_registry
global _compile_unitary_registry
global _compile_statemap_registry
global _compile_stateprep_registry
assert _compile_circuit_registry == {}
assert _compile_unitary_registry == {}
assert _compile_statemap_registry == {}
assert _compile_stateprep_registry == {}
gateset = [CZGate(), HGate(), RZGate()]
num_qudits = 3
machine = MachineModel(num_qudits, gate_set=gateset)
workflow = [QuickPartitioner(), ScanningGateRemovalPass()]
register_workflow(machine, workflow, 1)
assert machine in _compile_registry
assert 1 in _compile_registry[machine]
assert workflow_match(_compile_registry[machine][1], workflow)
circuit_workflow = [QuickPartitioner(), ScanningGateRemovalPass()]
other_workflow = [QuickPartitioner(), QSearchSynthesisPass()]
register_workflow(machine, circuit_workflow, 1, 'circuit')
register_workflow(machine, other_workflow, 1, 'unitary')
register_workflow(machine, other_workflow, 1, 'statemap')
register_workflow(machine, other_workflow, 1, 'stateprep')
assert machine in _compile_circuit_registry
assert 1 in _compile_circuit_registry[machine]
assert workflow_match(
_compile_circuit_registry[machine][1], circuit_workflow,
)
assert machine in _compile_unitary_registry
assert 1 in _compile_unitary_registry[machine]
assert workflow_match(
_compile_unitary_registry[machine][1], other_workflow,
)
assert machine in _compile_statemap_registry
assert 1 in _compile_statemap_registry[machine]
assert workflow_match(
_compile_statemap_registry[machine][1], other_workflow,
)
assert machine in _compile_stateprep_registry
assert 1 in _compile_stateprep_registry[machine]
assert workflow_match(
_compile_stateprep_registry[machine][1], other_workflow,
)

def test_custom_compile_machine(self) -> None:
global _compile_registry
assert _compile_registry == {}
global _compile_circuit_registry
assert _compile_circuit_registry == {}
gateset = [CZGate(), HGate(), RZGate()]
num_qudits = 3
machine = MachineModel(num_qudits, gate_set=gateset)
workflow = [QuickPartitioner(2)]
register_workflow(machine, workflow, 1)
register_workflow(machine, workflow, 1, 'circuit')
circuit = simple_circuit(num_qudits, gateset)
result = compile(circuit, machine)
assert unitary_match(result, circuit)
Expand All @@ -105,13 +138,13 @@ def test_custom_compile_machine(self) -> None:
assert result.gate_counts == circuit.gate_counts

def test_custom_opt_level(self) -> None:
global _compile_registry
assert _compile_registry == {}
global _compile_circuit_registry
assert _compile_circuit_registry == {}
gateset = [CZGate(), HGate(), RZGate()]
num_qudits = 3
machine = MachineModel(num_qudits, gate_set=gateset)
workflow = [QSearchSynthesisPass()]
register_workflow(machine, workflow, 2)
register_workflow(machine, workflow, 2, 'circuit')
circuit = simple_circuit(num_qudits, gateset)
result = compile(circuit, machine, optimization_level=2)
assert unitary_match(result, circuit)
Expand Down
Loading