diff --git a/bqskit/compiler/compile.py b/bqskit/compiler/compile.py index 83c3d888..4536f069 100644 --- a/bqskit/compiler/compile.py +++ b/bqskit/compiler/compile.py @@ -14,7 +14,11 @@ from bqskit.compiler.compiler import Compiler from bqskit.compiler.machine import MachineModel from bqskit.compiler.passdata import PassData -from bqskit.compiler.registry import _compile_registry +from bqskit.compiler.registry import _compile_circuit_registry +from bqskit.compiler.registry import _compile_statemap_registry +from bqskit.compiler.registry import _compile_stateprep_registry +from bqskit.compiler.registry import _compile_unitary_registry +from bqskit.compiler.registry import model_registered_target_types from bqskit.compiler.workflow import Workflow from bqskit.compiler.workflow import WorkflowLike from bqskit.ir.circuit import Circuit @@ -622,8 +626,11 @@ def type_and_check_input(input: CompilationInputLike) -> CompilationInput: if isinstance(typed_input, Circuit): in_circuit = typed_input + elif isinstance(typed_input, UnitaryMatrix): + in_circuit = Circuit.from_unitary(typed_input) + else: - in_circuit = Circuit(1) + in_circuit = Circuit(typed_input.num_qudits, typed_input.radixes) # Perform the compilation out, data = compiler.compile(in_circuit, workflow, True) @@ -669,11 +676,7 @@ def build_workflow( if model is None: model = MachineModel(input.num_qudits, radixes=input.radixes) - # Use a registered workflow if model is found in the registry for a given - # optimization_level - if model in _compile_registry: - if optimization_level in _compile_registry[model]: - return _compile_registry[model][optimization_level] + model_registered_types = model_registered_target_types(model) if isinstance(input, Circuit): if input.num_qudits > max_synthesis_size: @@ -691,6 +694,16 @@ def build_workflow( 'Unable to compile circuit with gate larger than' ' max_synthesis_size.\nConsider adjusting it.', ) + # Use a registered workflow if model is found in the circuit registry + # for a given optimization_level + if model in _compile_circuit_registry: + if optimization_level in _compile_circuit_registry[model]: + return _compile_circuit_registry[model][optimization_level] + elif len(model_registered_types) > 0: + m = f'MachineModel {model} is registered for inputs of type in ' + m += f'{model_registered_types}, but input is {type(input)}. ' + m += f'You may need to register a Workflow for type {type(input)}.' + warnings.warn(m) return _circuit_workflow( model, @@ -708,6 +721,16 @@ def build_workflow( 'Unable to compile unitary with size larger than' ' max_synthesis_size.\nConsider adjusting it.', ) + # Use a registered workflow if model is found in the unitary registry + # for a given optimization_level + if model in _compile_unitary_registry: + if optimization_level in _compile_unitary_registry[model]: + return _compile_unitary_registry[model][optimization_level] + elif len(model_registered_types) > 0: + m = f'MachineModel {model} is registered for inputs of type in ' + m += f'{model_registered_types}, but input is {type(input)}. ' + m += f'You may need to register a Workflow for type {type(input)}.' + warnings.warn(m) return _synthesis_workflow( input, @@ -726,6 +749,16 @@ def build_workflow( 'Unable to compile states with size larger than' ' max_synthesis_size.\nConsider adjusting it.', ) + # Use a registered workflow if model is found in the stateprep registry + # for a given optimization_level + if model in _compile_stateprep_registry: + if optimization_level in _compile_stateprep_registry[model]: + return _compile_stateprep_registry[model][optimization_level] + elif len(model_registered_types) > 0: + m = f'MachineModel {model} is registered for inputs of type in ' + m += f'{model_registered_types}, but input is {type(input)}. ' + m += f'You may need to register a Workflow for type {type(input)}.' + warnings.warn(m) return _stateprep_workflow( input, @@ -744,6 +777,16 @@ def build_workflow( 'Unable to compile state systems with size larger than' ' max_synthesis_size.\nConsider adjusting it.', ) + # Use a registered workflow if model is found in the statemap registry + # for a given optimization_level + if model in _compile_statemap_registry: + if optimization_level in _compile_statemap_registry[model]: + return _compile_statemap_registry[model][optimization_level] + elif len(model_registered_types) > 0: + m = f'MachineModel {model} is registered for inputs of type in ' + m += f'{model_registered_types}, but input is {type(input)}. ' + m += f'You may need to register a Workflow for type {type(input)}.' + warnings.warn(m) return _statemap_workflow( input, diff --git a/bqskit/compiler/registry.py b/bqskit/compiler/registry.py index b3cb1b4d..9a2d5d45 100644 --- a/bqskit/compiler/registry.py +++ b/bqskit/compiler/registry.py @@ -8,13 +8,44 @@ from bqskit.compiler.workflow import WorkflowLike -_compile_registry: dict[MachineModel, dict[int, Workflow]] = {} +_compile_circuit_registry: dict[MachineModel, dict[int, Workflow]] = {} +_compile_unitary_registry: dict[MachineModel, dict[int, Workflow]] = {} +_compile_stateprep_registry: dict[MachineModel, dict[int, Workflow]] = {} +_compile_statemap_registry: dict[MachineModel, dict[int, Workflow]] = {} + + +def model_registered_target_types(key: MachineModel) -> list[str]: + """ + Return a list of target_types for which key is registered. + + Args: + key (MachineModel): A MachineModel to check for. + + Returns: + (list[str]): If `key` has been registered in any of the registry, the + name of that target type will be contained in this list. + """ + global _compile_circuit_registry + global _compile_unitary_registry + global _compile_stateprep_registry + global _compile_statemap_registry + registered_types = [] + if key in _compile_circuit_registry: + registered_types.append('circuit') + if key in _compile_unitary_registry: + registered_types.append('unitary') + if key in _compile_stateprep_registry: + registered_types.append('stateprep') + if key in _compile_statemap_registry: + registered_types.append('statemap') + return registered_types def register_workflow( key: MachineModel, workflow: WorkflowLike, optimization_level: int, + target_type: str = 'circuit', ) -> None: """ Register a workflow for a given MachineModel. @@ -34,30 +65,54 @@ def register_workflow( be executed if the MachineModel in a call to `compile` matches `key`. If `key` is already registered, a warning will be logged. - optimization_level ptional[int): The optimization level with which - to register the workflow. If no level is provided, the Workflow - will be registered as level 1. + optimization_level (int): The optimization level with which to + register the workflow. + + target_type (str): Register a workflow for targets of this type. Must + be 'circuit', 'unitary', 'stateprep', or 'statemap'. + (Default: 'circuit') Example: model_t = SpecificMachineModel(num_qudits, radixes) workflow = [QuickPartitioner(3), NewFangledOptimization()] - register_workflow(model_t, workflow, level) + register_workflow(model_t, workflow, level, 'circuit') ... new_circuit = compile(circuit, model_t, optimization_level=level) Raises: Warning: If a workflow for a given optimization_level is overwritten. + + ValueError: If `target_type` is not 'circuit', 'unitary', 'stateprep', + or 'statemap'. """ + if target_type not in ['circuit', 'unitary', 'stateprep', 'statemap']: + m = 'target_type must be "circuit", "unitary", "stateprep", or ' + m += f'"statemap", got {target_type}.' + raise ValueError(m) + + if target_type == 'circuit': + global _compile_circuit_registry + _compile_registry = _compile_circuit_registry + elif target_type == 'unitary': + global _compile_unitary_registry + _compile_registry = _compile_unitary_registry + elif target_type == 'stateprep': + global _compile_stateprep_registry + _compile_registry = _compile_stateprep_registry + else: + global _compile_statemap_registry + _compile_registry = _compile_statemap_registry + workflow = Workflow(workflow) - global _compile_registry new_workflow = {optimization_level: workflow} if key in _compile_registry: if optimization_level in _compile_registry[key]: m = f'Overwritting workflow for {key} at level ' - m += f'{optimization_level}. If multiple Namespace packages are ' - m += 'installed, ensure that their __init__.py files do not ' - m += 'attempt to overwrite the same default Workflows.' + m += f'{optimization_level} for target type {target_type}.' + m += 'If multiple Namespace packages are installed, ensure' + m += 'that their __init__.py files do not attempt to' + m += 'overwrite the same default Workflows.' warnings.warn(m) _compile_registry[key].update(new_workflow) else: diff --git a/tests/compiler/test_registry.py b/tests/compiler/test_registry.py index 6371211c..4cc6065b 100644 --- a/tests/compiler/test_registry.py +++ b/tests/compiler/test_registry.py @@ -9,7 +9,10 @@ from bqskit.compiler.compile import compile from bqskit.compiler.machine import MachineModel -from bqskit.compiler.registry import _compile_registry +from bqskit.compiler.registry import _compile_circuit_registry +from bqskit.compiler.registry import _compile_statemap_registry +from bqskit.compiler.registry import _compile_stateprep_registry +from bqskit.compiler.registry import _compile_unitary_registry from bqskit.compiler.registry import register_workflow from bqskit.compiler.workflow import Workflow from bqskit.compiler.workflow import WorkflowLike @@ -74,28 +77,58 @@ class TestRegisterWorkflow: @pytest.fixture(autouse=True) def setup(self) -> None: # global _compile_registry - _compile_registry.clear() + _compile_circuit_registry.clear() + _compile_unitary_registry.clear() + _compile_statemap_registry.clear() + _compile_stateprep_registry.clear() def test_register_workflow(self) -> None: - global _compile_registry - assert _compile_registry == {} + global _compile_circuit_registry + global _compile_unitary_registry + global _compile_statemap_registry + global _compile_stateprep_registry + assert _compile_circuit_registry == {} + assert _compile_unitary_registry == {} + assert _compile_statemap_registry == {} + assert _compile_stateprep_registry == {} gateset = [CZGate(), HGate(), RZGate()] num_qudits = 3 machine = MachineModel(num_qudits, gate_set=gateset) - workflow = [QuickPartitioner(), ScanningGateRemovalPass()] - register_workflow(machine, workflow, 1) - assert machine in _compile_registry - assert 1 in _compile_registry[machine] - assert workflow_match(_compile_registry[machine][1], workflow) + circuit_workflow = [QuickPartitioner(), ScanningGateRemovalPass()] + other_workflow = [QuickPartitioner(), QSearchSynthesisPass()] + register_workflow(machine, circuit_workflow, 1, 'circuit') + register_workflow(machine, other_workflow, 1, 'unitary') + register_workflow(machine, other_workflow, 1, 'statemap') + register_workflow(machine, other_workflow, 1, 'stateprep') + assert machine in _compile_circuit_registry + assert 1 in _compile_circuit_registry[machine] + assert workflow_match( + _compile_circuit_registry[machine][1], circuit_workflow, + ) + assert machine in _compile_unitary_registry + assert 1 in _compile_unitary_registry[machine] + assert workflow_match( + _compile_unitary_registry[machine][1], other_workflow, + ) + assert machine in _compile_statemap_registry + assert 1 in _compile_statemap_registry[machine] + assert workflow_match( + _compile_statemap_registry[machine][1], other_workflow, + ) + assert machine in _compile_stateprep_registry + assert 1 in _compile_stateprep_registry[machine] + assert workflow_match( + _compile_stateprep_registry[machine][1], other_workflow, + ) def test_custom_compile_machine(self) -> None: - global _compile_registry - assert _compile_registry == {} + global _compile_circuit_registry + assert _compile_circuit_registry == {} gateset = [CZGate(), HGate(), RZGate()] num_qudits = 3 machine = MachineModel(num_qudits, gate_set=gateset) workflow = [QuickPartitioner(2)] - register_workflow(machine, workflow, 1) + register_workflow(machine, workflow, 1, 'circuit') circuit = simple_circuit(num_qudits, gateset) result = compile(circuit, machine) assert unitary_match(result, circuit) @@ -105,13 +138,13 @@ def test_custom_compile_machine(self) -> None: assert result.gate_counts == circuit.gate_counts def test_custom_opt_level(self) -> None: - global _compile_registry - assert _compile_registry == {} + global _compile_circuit_registry + assert _compile_circuit_registry == {} gateset = [CZGate(), HGate(), RZGate()] num_qudits = 3 machine = MachineModel(num_qudits, gate_set=gateset) workflow = [QSearchSynthesisPass()] - register_workflow(machine, workflow, 2) + register_workflow(machine, workflow, 2, 'circuit') circuit = simple_circuit(num_qudits, gateset) result = compile(circuit, machine, optimization_level=2) assert unitary_match(result, circuit)