diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bcdf16e4f..5fd25af39 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,7 @@ ci: skip: [mypy] repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.6.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer @@ -30,7 +30,7 @@ repos: - --wrap-summaries=80 - --wrap-descriptions=80 - repo: https://github.com/pre-commit/mirrors-autopep8 - rev: v2.0.2 + rev: v2.0.4 hooks: - id: autopep8 args: @@ -39,13 +39,13 @@ repos: - --ignore=E731 exclude: 'tests/ext.*' - repo: https://github.com/asottile/pyupgrade - rev: v3.10.1 + rev: v3.17.0 hooks: - id: pyupgrade args: - --py38-plus - repo: https://github.com/asottile/reorder_python_imports - rev: v3.10.0 + rev: v3.13.0 hooks: - id: reorder-python-imports args: @@ -54,25 +54,25 @@ repos: - --py37-plus exclude: 'tests/ext.*' - repo: https://github.com/asottile/add-trailing-comma - rev: v3.0.1 + rev: v3.1.0 hooks: - id: add-trailing-comma args: - --py36-plus - repo: https://github.com/PyCQA/autoflake - rev: v2.2.0 + rev: v2.3.1 hooks: - id: autoflake args: - --in-place - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.5.0 + rev: v1.11.1 hooks: - id: mypy exclude: tests/qis/test_pauli.py additional_dependencies: ["numpy>=1.21"] - repo: https://github.com/PyCQA/flake8 - rev: 6.1.0 + rev: 7.1.1 hooks: - id: flake8 args: diff --git a/bqskit/__init__.py b/bqskit/__init__.py index dd6a87128..83ca808bb 100644 --- a/bqskit/__init__.py +++ b/bqskit/__init__.py @@ -1,91 +1,29 @@ """The Berkeley Quantum Synthesis Toolkit Python Package.""" from __future__ import annotations -import logging -from sys import stdout as _stdout +from typing import Any -from .version import __version__ # noqa: F401 -from .version import __version_info__ # noqa: F401 -from bqskit.compiler.compile import compile -from bqskit.compiler.machine import MachineModel -from bqskit.ir.circuit import Circuit -from bqskit.ir.lang import 
register_language as _register_language -from bqskit.ir.lang.qasm2 import OPENQASM2Language as _qasm +from bqskit._logging import disable_logging +from bqskit._logging import enable_logging +from bqskit._version import __version__ # noqa: F401 +from bqskit._version import __version_info__ # noqa: F401 -# Initialize Logging -_logging_initialized = False +def __getattr__(name: str) -> Any: + # Lazy imports + if name == 'compile': + from bqskit.compiler.compile import compile + return compile -def enable_logging(verbose: bool = False) -> None: - """ - Enable logging for BQSKit. + if name == 'Circuit': + from bqskit.ir.circuit import Circuit + return Circuit - Args: - verbose (bool): If set to True, will print more verbose messages. - Defaults to False. - """ - global _logging_initialized - if not _logging_initialized: - _logger = logging.getLogger('bqskit') - _handler = logging.StreamHandler(_stdout) - _handler.setLevel(0) - _fmt_header = '%(asctime)s.%(msecs)03d - %(levelname)-8s |' - _fmt_message = ' %(name)s: %(message)s' - _fmt = _fmt_header + _fmt_message - _formatter = logging.Formatter(_fmt, '%H:%M:%S') - _handler.setFormatter(_formatter) - _logger.addHandler(_handler) - _logging_initialized = True + if name == 'MachineModel': + from bqskit.compiler.machine import MachineModel + return MachineModel - level = logging.DEBUG if verbose else logging.INFO - logging.getLogger('bqskit').setLevel(level) - - -def disable_logging() -> None: - """Disable logging for BQSKit.""" - logging.getLogger('bqskit').setLevel(logging.CRITICAL) - - -def enable_dashboard() -> None: - import warnings - warnings.warn( - 'Dask has been removed from BQSKit. As a result, the' - ' enable_dashboard method has been removed.' - 'This warning will turn into an error in a future update.', - DeprecationWarning, - ) - - -def disable_dashboard() -> None: - import warnings - warnings.warn( - 'Dask has been removed from BQSKit. As a result, the' - ' disable_dashboard method has been removed.' 
- 'This warning will turn into an error in a future update.', - DeprecationWarning, - ) - - -def disable_parallelism() -> None: - import warnings - warnings.warn( - 'The disable_parallelism method has been removed.' - ' Instead, set the "num_workers" parameter to 1 during ' - 'Compiler construction. This warning will turn into' - 'an error in a future update.', - DeprecationWarning, - ) - - -def enable_parallelism() -> None: - import warnings - warnings.warn( - 'The enable_parallelism method has been removed.' - ' Instead, set the "num_workers" parameter to 1 during ' - 'Compiler construction. This warning will turn into' - 'an error in a future update.', - DeprecationWarning, - ) + raise AttributeError(f'module {__name__} has no attribute {name}') __all__ = [ @@ -95,6 +33,3 @@ def enable_parallelism() -> None: 'enable_logging', 'disable_logging', ] - -# Register supported languages -_register_language('qasm', _qasm()) diff --git a/bqskit/_logging.py b/bqskit/_logging.py new file mode 100644 index 000000000..59b079156 --- /dev/null +++ b/bqskit/_logging.py @@ -0,0 +1,38 @@ +"""This module contains the logging configuration and methods for BQSKit.""" +from __future__ import annotations + +import logging +from sys import stdout as _stdout + + +_logging_initialized = False + + +def enable_logging(verbose: bool = False) -> None: + """ + Enable logging for BQSKit. + + Args: + verbose (bool): If set to True, will print more verbose messages. + Defaults to False. 
+ """ + global _logging_initialized + if not _logging_initialized: + _logger = logging.getLogger('bqskit') + _handler = logging.StreamHandler(_stdout) + _handler.setLevel(0) + _fmt_header = '%(asctime)s.%(msecs)03d - %(levelname)-8s |' + _fmt_message = ' %(name)s: %(message)s' + _fmt = _fmt_header + _fmt_message + _formatter = logging.Formatter(_fmt, '%H:%M:%S') + _handler.setFormatter(_formatter) + _logger.addHandler(_handler) + _logging_initialized = True + + level = logging.DEBUG if verbose else logging.INFO + logging.getLogger('bqskit').setLevel(level) + + +def disable_logging() -> None: + """Disable logging for BQSKit.""" + logging.getLogger('bqskit').setLevel(logging.CRITICAL) diff --git a/bqskit/version.py b/bqskit/_version.py similarity index 83% rename from bqskit/version.py rename to bqskit/_version.py index b035796df..726c5cd4b 100644 --- a/bqskit/version.py +++ b/bqskit/_version.py @@ -1,4 +1,4 @@ """This module contains the version information for BQSKit.""" from __future__ import annotations -__version_info__ = ('1', '1', '2') +__version_info__ = ('1', '2', '0') __version__ = '.'.join(__version_info__[:3]) + ''.join(__version_info__[3:]) diff --git a/bqskit/compiler/__init__.py b/bqskit/compiler/__init__.py index f7048aa56..fde8f87ee 100644 --- a/bqskit/compiler/__init__.py +++ b/bqskit/compiler/__init__.py @@ -37,8 +37,9 @@ """ from __future__ import annotations +from typing import Any + from bqskit.compiler.basepass import BasePass -from bqskit.compiler.compile import compile from bqskit.compiler.compiler import Compiler from bqskit.compiler.gateset import GateSet from bqskit.compiler.gateset import GateSetLike @@ -49,6 +50,19 @@ from bqskit.compiler.workflow import Workflow from bqskit.compiler.workflow import WorkflowLike + +def __getattr__(name: str) -> Any: + # Lazy imports + if name == 'compile': + # TODO: fix this (high-priority), overlap between module and function + from bqskit.compiler.compile import compile + return compile + + # TODO: 
Move compile to a different subpackage and deprecate import + + raise AttributeError(f'module {__name__} has no attribute {name}') + + __all__ = [ 'BasePass', 'compile', diff --git a/bqskit/compiler/compile.py b/bqskit/compiler/compile.py index 4824c42ad..4536f0697 100644 --- a/bqskit/compiler/compile.py +++ b/bqskit/compiler/compile.py @@ -2,6 +2,7 @@ from __future__ import annotations import logging +import math import warnings from typing import Any from typing import Literal @@ -10,11 +11,14 @@ from typing import TYPE_CHECKING from typing import Union -import numpy as np - from bqskit.compiler.compiler import Compiler from bqskit.compiler.machine import MachineModel from bqskit.compiler.passdata import PassData +from bqskit.compiler.registry import _compile_circuit_registry +from bqskit.compiler.registry import _compile_statemap_registry +from bqskit.compiler.registry import _compile_stateprep_registry +from bqskit.compiler.registry import _compile_unitary_registry +from bqskit.compiler.registry import model_registered_target_types from bqskit.compiler.workflow import Workflow from bqskit.compiler.workflow import WorkflowLike from bqskit.ir.circuit import Circuit @@ -582,7 +586,7 @@ def type_and_check_input(input: CompilationInputLike) -> CompilationInput: if error_threshold is not None: for i, data in enumerate(datas): error = data.error - nonsq_error = 1 - np.sqrt(max(1 - (error * error), 0)) + nonsq_error = 1 - math.sqrt(max(1 - (error * error), 0)) if nonsq_error > error_threshold: warnings.warn( 'Upper bound on error is greater than set threshold:' @@ -622,8 +626,11 @@ def type_and_check_input(input: CompilationInputLike) -> CompilationInput: if isinstance(typed_input, Circuit): in_circuit = typed_input + elif isinstance(typed_input, UnitaryMatrix): + in_circuit = Circuit.from_unitary(typed_input) + else: - in_circuit = Circuit(1) + in_circuit = Circuit(typed_input.num_qudits, typed_input.radixes) # Perform the compilation out, data = 
compiler.compile(in_circuit, workflow, True) @@ -631,7 +638,7 @@ def type_and_check_input(input: CompilationInputLike) -> CompilationInput: # Log error if necessary if error_threshold is not None: error = data.error - nonsq_error = 1 - np.sqrt(max(1 - (error * error), 0)) + nonsq_error = 1 - math.sqrt(max(1 - (error * error), 0)) if nonsq_error > error_threshold: warnings.warn( 'Upper bound on error is greater than set threshold:' @@ -669,6 +676,8 @@ def build_workflow( if model is None: model = MachineModel(input.num_qudits, radixes=input.radixes) + model_registered_types = model_registered_target_types(model) + if isinstance(input, Circuit): if input.num_qudits > max_synthesis_size: if any( @@ -685,6 +694,16 @@ def build_workflow( 'Unable to compile circuit with gate larger than' ' max_synthesis_size.\nConsider adjusting it.', ) + # Use a registered workflow if model is found in the circuit registry + # for a given optimization_level + if model in _compile_circuit_registry: + if optimization_level in _compile_circuit_registry[model]: + return _compile_circuit_registry[model][optimization_level] + elif len(model_registered_types) > 0: + m = f'MachineModel {model} is registered for inputs of type in ' + m += f'{model_registered_types}, but input is {type(input)}. ' + m += f'You may need to register a Workflow for type {type(input)}.' + warnings.warn(m) return _circuit_workflow( model, @@ -702,6 +721,16 @@ def build_workflow( 'Unable to compile unitary with size larger than' ' max_synthesis_size.\nConsider adjusting it.', ) + # Use a registered workflow if model is found in the unitary registry + # for a given optimization_level + if model in _compile_unitary_registry: + if optimization_level in _compile_unitary_registry[model]: + return _compile_unitary_registry[model][optimization_level] + elif len(model_registered_types) > 0: + m = f'MachineModel {model} is registered for inputs of type in ' + m += f'{model_registered_types}, but input is {type(input)}. 
' + m += f'You may need to register a Workflow for type {type(input)}.' + warnings.warn(m) return _synthesis_workflow( input, @@ -720,6 +749,16 @@ def build_workflow( 'Unable to compile states with size larger than' ' max_synthesis_size.\nConsider adjusting it.', ) + # Use a registered workflow if model is found in the stateprep registry + # for a given optimization_level + if model in _compile_stateprep_registry: + if optimization_level in _compile_stateprep_registry[model]: + return _compile_stateprep_registry[model][optimization_level] + elif len(model_registered_types) > 0: + m = f'MachineModel {model} is registered for inputs of type in ' + m += f'{model_registered_types}, but input is {type(input)}. ' + m += f'You may need to register a Workflow for type {type(input)}.' + warnings.warn(m) return _stateprep_workflow( input, @@ -738,6 +777,16 @@ def build_workflow( 'Unable to compile state systems with size larger than' ' max_synthesis_size.\nConsider adjusting it.', ) + # Use a registered workflow if model is found in the statemap registry + # for a given optimization_level + if model in _compile_statemap_registry: + if optimization_level in _compile_statemap_registry[model]: + return _compile_statemap_registry[model][optimization_level] + elif len(model_registered_types) > 0: + m = f'MachineModel {model} is registered for inputs of type in ' + m += f'{model_registered_types}, but input is {type(input)}. ' + m += f'You may need to register a Workflow for type {type(input)}.' 
+ warnings.warn(m) return _statemap_workflow( input, diff --git a/bqskit/compiler/compiler.py b/bqskit/compiler/compiler.py index 55bb50a1d..e68d5a6d4 100644 --- a/bqskit/compiler/compiler.py +++ b/bqskit/compiler/compiler.py @@ -4,17 +4,18 @@ import atexit import functools import logging +import pickle import signal import subprocess import sys import time import uuid -import warnings from multiprocessing.connection import Client from multiprocessing.connection import Connection from subprocess import Popen from types import FrameType from typing import Literal +from typing import MutableMapping from typing import overload from typing import TYPE_CHECKING @@ -71,6 +72,7 @@ def __init__( num_workers: int = -1, runtime_log_level: int = logging.WARNING, worker_port: int = default_worker_port, + num_blas_threads: int = 1, ) -> None: """ Construct a Compiler object. @@ -99,30 +101,42 @@ def __init__( worker_port (int): The optional port to pass to an attached runtime. See :obj:`~bqskit.runtime.attached.AttachedServer` for more info. + + num_blas_threads (int): The number of threads to use in the + BLAS libraries on the worker nodes. (Defaults to 1) """ self.p: Popen | None = None # type: ignore self.conn: Connection | None = None - atexit.register(self.close) + _compiler_instances.add(self) if ip is None: ip = 'localhost' - self._start_server(num_workers, runtime_log_level, worker_port) + self._start_server( + num_workers, + runtime_log_level, + worker_port, + num_blas_threads, + ) - self._connect_to_server(ip, port) + self._connect_to_server(ip, port, self.p is not None) def _start_server( self, num_workers: int, runtime_log_level: int, worker_port: int, + num_blas_threads: int, ) -> None: """ Start an attached serer with `num_workers` workers. See :obj:`~bqskit.runtime.attached.AttachedServer` for more info. 
""" - params = f'{num_workers}, {runtime_log_level}, {worker_port=}' + params = f'{num_workers}, ' + params += f'log_level={runtime_log_level}, ' + params += f'{worker_port=}, ' + params += f'{num_blas_threads=}, ' import_str = 'from bqskit.runtime.attached import start_attached_server' launch_str = f'{import_str}; start_attached_server({params})' if sys.platform == 'win32': @@ -132,24 +146,45 @@ def _start_server( self.p = Popen([sys.executable, '-c', launch_str], creationflags=flags) _logger.debug('Starting runtime server process.') - def _connect_to_server(self, ip: str, port: int) -> None: + def _connect_to_server(self, ip: str, port: int, attached: bool) -> None: """Connect to a runtime server at `ip` and `port`.""" max_retries = 8 wait_time = .25 - for _ in range(max_retries): + current_retry = 0 + while current_retry < max_retries or attached: try: family = 'AF_INET' if sys.platform == 'win32' else None conn = Client((ip, port), family) except ConnectionRefusedError: + if wait_time > 4: + _logger.warning( + 'Connection refused by runtime server.' + ' Retrying in %s seconds.', wait_time, + ) + if wait_time > 16 and attached: + _logger.warning( + 'Connection is still refused by runtime server.' + ' This may be due to the server not being started.' + ' You may want to check the server logs, by starting' + ' the compiler with "runtime_log_level" set. You' + ' can also try launching the bqskit runtime in' + ' detached mode. 
See the bqskit runtime documentation' + ' for more information:' + ' https://bqskit.readthedocs.io/en/latest/guides/' + 'distributing.html', + ) time.sleep(wait_time) wait_time *= 2 + current_retry += 1 else: self.conn = conn handle = functools.partial(sigint_handler, compiler=self) self.old_signal = signal.signal(signal.SIGINT, handle) if self.conn is None: raise RuntimeError('Connection unexpectedly none.') - self.conn.send((RuntimeMessage.CONNECT, None)) + msg, payload = self._send_recv(RuntimeMessage.CONNECT, sys.path) + if msg != RuntimeMessage.READY: + raise RuntimeError(f'Unexpected message type: {msg}.') _logger.debug('Successfully connected to runtime server.') return raise RuntimeError('Client connection refused') @@ -221,28 +256,25 @@ def close(self) -> None: # Reset interrupt signal handler and remove exit handler if hasattr(self, 'old_signal'): signal.signal(signal.SIGINT, self.old_signal) + del self.old_signal - def __del__(self) -> None: - self.close() - atexit.unregister(self.close) - _logger.debug('Compiler successfully shutdown.') + _compiler_instances.discard(self) + _logger.debug('Compiler has been closed.') def submit( self, - task_or_circuit: CompilationTask | Circuit, - workflow: WorkflowLike | None = None, + circuit: Circuit, + workflow: WorkflowLike, request_data: bool = False, logging_level: int | None = None, max_logging_depth: int = -1, + data: MutableMapping[str, Any] | None = None, ) -> uuid.UUID: """ Submit a compilation job to the Compiler. Args: - task_or_circuit (CompilationTask | Circuit): The task to compile, - or the input circuit. If a task is specified, no other - argument should be specified. If a task is not specified, - the circuit must be paired with a workflow argument. + circuit (Circuit): The input circuit to be compiled. workflow (WorkflowLike): The compilation job submitted is defined by executing this workflow on the input circuit. @@ -262,91 +294,48 @@ def submit( tasks equal opportunity to log. 
Returns: - (uuid.UUID): The ID of the generated task in the system. This + uuid.UUID: The ID of the generated task in the system. This ID can be used to check the status of, cancel, and request the result of the task. """ # Build CompilationTask - if isinstance(task_or_circuit, CompilationTask): - if workflow is not None: - raise ValueError( - 'Cannot specify workflow and task.' - ' Either specify a workflow and circuit or a task alone.', - ) - - task = task_or_circuit - - else: - if workflow is None: - m = 'Must specify workflow when providing a circuit to submit.' - raise TypeError(m) - - task = CompilationTask(task_or_circuit, Workflow(workflow)) + task = CompilationTask(circuit, Workflow(workflow)) # Set task configuration task.request_data = request_data task.logging_level = logging_level or self._discover_lowest_log_level() task.max_logging_depth = max_logging_depth + if data is not None: + task.data.update(data) # Submit task to runtime self._send(RuntimeMessage.SUBMIT, task) return task.task_id - def status(self, task_id: CompilationTask | uuid.UUID) -> CompilationStatus: - """Retrieve the status of the specified task.""" - if isinstance(task_id, CompilationTask): - warnings.warn( - 'Request a status from a CompilationTask is deprecated.\n' - ' Instead, pass a task ID to request a status.\n' - ' `compiler.submit` returns a task id, and you can get an\n' - ' ID from a task via `task.task_id`.\n' - ' This warning will turn into an error in a future update.', - DeprecationWarning, - ) - task_id = task_id.task_id - assert isinstance(task_id, uuid.UUID) + def status(self, task_id: uuid.UUID) -> CompilationStatus: + """ + Retrieve the status of the specified task. + + Args: + task_id (uuid.UUID): The ID of the task to check. + Returns: + CompilationStatus: The status of the task. 
+ """ msg, payload = self._send_recv(RuntimeMessage.STATUS, task_id) if msg != RuntimeMessage.STATUS: raise RuntimeError(f'Unexpected message type: {msg}.') return payload - def result( - self, - task_id: CompilationTask | uuid.UUID, - ) -> Circuit | tuple[Circuit, PassData]: + def result(self, task_id: uuid.UUID) -> Circuit | tuple[Circuit, PassData]: """Block until the task is finished, return its result.""" - if isinstance(task_id, CompilationTask): - warnings.warn( - 'Request a result from a CompilationTask is deprecated.' - ' Instead, pass a task ID to request a result.\n' - ' `compiler.submit` returns a task id, and you can get an\n' - ' ID from a task via `task.task_id`.\n' - ' This warning will turn into an error in a future update.', - DeprecationWarning, - ) - task_id = task_id.task_id - assert isinstance(task_id, uuid.UUID) - msg, payload = self._send_recv(RuntimeMessage.REQUEST, task_id) if msg != RuntimeMessage.RESULT: raise RuntimeError(f'Unexpected message type: {msg}.') return payload - def cancel(self, task_id: CompilationTask | uuid.UUID) -> bool: + def cancel(self, task_id: uuid.UUID) -> bool: """Cancel the execution of a task in the system.""" - if isinstance(task_id, CompilationTask): - warnings.warn( - 'Cancelling a CompilationTask is deprecated. Instead,' - ' Instead, pass a task ID to cancel a task.\n' - ' `compiler.submit` returns a task id, and you can get an\n' - ' ID from a task via `task.task_id`.\n' - ' This warning will turn into an error in a future update.', - DeprecationWarning, - ) - task_id = task_id.task_id - assert isinstance(task_id, uuid.UUID) - msg, _ = self._send_recv(RuntimeMessage.CANCEL, task_id) if msg != RuntimeMessage.CANCEL: raise RuntimeError(f'Unexpected message type: {msg}.') @@ -355,63 +344,51 @@ def cancel(self, task_id: CompilationTask | uuid.UUID) -> bool: @overload def compile( self, - task_or_circuit: CompilationTask, - ) -> Circuit | tuple[Circuit, PassData]: - ... 
- - @overload - def compile( - self, - task_or_circuit: Circuit, + circuit: Circuit, workflow: WorkflowLike, request_data: Literal[False] = ..., logging_level: int | None = ..., max_logging_depth: int = ..., + data: MutableMapping[str, Any] | None = ..., ) -> Circuit: ... @overload def compile( self, - task_or_circuit: Circuit, + circuit: Circuit, workflow: WorkflowLike, request_data: Literal[True], logging_level: int | None = ..., max_logging_depth: int = ..., + data: MutableMapping[str, Any] | None = ..., ) -> tuple[Circuit, PassData]: ... @overload def compile( self, - task_or_circuit: Circuit, + circuit: Circuit, workflow: WorkflowLike, request_data: bool, logging_level: int | None = ..., max_logging_depth: int = ..., + data: MutableMapping[str, Any] | None = ..., ) -> Circuit | tuple[Circuit, PassData]: ... def compile( self, - task_or_circuit: CompilationTask | Circuit, - workflow: WorkflowLike | None = None, + circuit: Circuit, + workflow: WorkflowLike, request_data: bool = False, logging_level: int | None = None, max_logging_depth: int = -1, + data: MutableMapping[str, Any] | None = None, ) -> Circuit | tuple[Circuit, PassData]: """Submit a task, wait for its results; see :func:`submit` for more.""" - if isinstance(task_or_circuit, CompilationTask): - warnings.warn( - 'Manually constructing and compiling CompilationTasks' - ' is deprecated. Instead, call compile directly with' - ' your input circuit and workflow. 
This warning will' - ' turn into an error in a future update.', - DeprecationWarning, - ) - task_id = self.submit( - task_or_circuit, + circuit, workflow, request_data, logging_level, @@ -438,7 +415,12 @@ def _send(self, msg: RuntimeMessage, payload: Any) -> None: except Exception as e: self.conn = None self.close() - raise RuntimeError('Server connection unexpectedly closed.') from e + if isinstance(e, (EOFError, ConnectionResetError)): + raise RuntimeError('Server connection unexpectedly closed.') + else: + raise RuntimeError( + 'Server connection unexpectedly closed.', + ) from e def _send_recv( self, @@ -471,9 +453,15 @@ def _recv_handle_log_error(self) -> tuple[RuntimeMessage, Any]: msg, payload = self.conn.recv() if msg == RuntimeMessage.LOG: - logger = logging.getLogger(payload.name) - if logger.isEnabledFor(payload.levelno): - logger.handle(payload) + record = pickle.loads(payload) + if isinstance(record, logging.LogRecord): + logger = logging.getLogger(record.name) + if logger.isEnabledFor(record.levelno): + logger.handle(record) + else: + name, levelno, msg = record + logger = logging.getLogger(name) + logger.log(levelno, msg) elif msg == RuntimeMessage.ERROR: raise RuntimeError(payload) @@ -530,3 +518,12 @@ def sigint_handler(signum: int, frame: FrameType, compiler: Compiler) -> None: _logger.critical('Compiler interrupted.') compiler.close() raise KeyboardInterrupt + + +_compiler_instances: set[Compiler] = set() + + +@atexit.register +def _cleanup_compiler_instances() -> None: + for compiler in list(_compiler_instances): + compiler.close() diff --git a/bqskit/compiler/gateset.py b/bqskit/compiler/gateset.py index d8111ec4a..be50ce692 100644 --- a/bqskit/compiler/gateset.py +++ b/bqskit/compiler/gateset.py @@ -231,5 +231,9 @@ def __repr__(self) -> str: """Detailed representation of the GateSet.""" return self._gates.__repr__().replace('frozenset', 'GateSet') + def __hash__(self) -> int: + """Hash of the GateSet.""" + return hash(tuple(sorted([g.name for 
g in self._gates]))) + GateSetLike = Union[GateSet, Iterable[Gate], Gate] diff --git a/bqskit/compiler/passdata.py b/bqskit/compiler/passdata.py index 160d4f44e..23d584712 100644 --- a/bqskit/compiler/passdata.py +++ b/bqskit/compiler/passdata.py @@ -252,6 +252,24 @@ def __contains__(self, _o: object) -> bool: in_data = self._data.__contains__(_o) return in_resv or in_data + def update(self, other: Any = (), /, **kwds: Any) -> None: + """Update the data with key-values pairs from `other` and `kwds`.""" + if isinstance(other, PassData): + for key in other: + # Handle target specially to avoid circuit evaluation + if key == 'target': + self._target = other._target + continue + + self[key] = other[key] + + for key, value in kwds.items(): + self[key] = value + + return + + super().update(other, **kwds) + def copy(self) -> PassData: """Returns a deep copy of the data.""" return copy.deepcopy(self) diff --git a/bqskit/compiler/registry.py b/bqskit/compiler/registry.py new file mode 100644 index 000000000..9a2d5d452 --- /dev/null +++ b/bqskit/compiler/registry.py @@ -0,0 +1,119 @@ +"""Register MachineModel specific default workflows.""" +from __future__ import annotations + +import warnings + +from bqskit.compiler.machine import MachineModel +from bqskit.compiler.workflow import Workflow +from bqskit.compiler.workflow import WorkflowLike + + +_compile_circuit_registry: dict[MachineModel, dict[int, Workflow]] = {} +_compile_unitary_registry: dict[MachineModel, dict[int, Workflow]] = {} +_compile_stateprep_registry: dict[MachineModel, dict[int, Workflow]] = {} +_compile_statemap_registry: dict[MachineModel, dict[int, Workflow]] = {} + + +def model_registered_target_types(key: MachineModel) -> list[str]: + """ + Return a list of target_types for which key is registered. + + Args: + key (MachineModel): A MachineModel to check for. + + Returns: + (list[str]): If `key` has been registered in any of the registry, the + name of that target type will be contained in this list. 
+ """ + global _compile_circuit_registry + global _compile_unitary_registry + global _compile_stateprep_registry + global _compile_statemap_registry + registered_types = [] + if key in _compile_circuit_registry: + registered_types.append('circuit') + if key in _compile_unitary_registry: + registered_types.append('unitary') + if key in _compile_stateprep_registry: + registered_types.append('stateprep') + if key in _compile_statemap_registry: + registered_types.append('statemap') + return registered_types + + +def register_workflow( + key: MachineModel, + workflow: WorkflowLike, + optimization_level: int, + target_type: str = 'circuit', +) -> None: + """ + Register a workflow for a given MachineModel. + + The _compile_registry enables MachineModel specific workflows to be + registered for use in the `bqskit.compile` method. _compile_registry maps + MachineModels a dictionary of Workflows which are indexed by optimization + level. This object should not be accessed directly by the user, but + instead through the `register_workflow` function. + + Args: + key (MachineModel): A MachineModel to register the workflow under. + If a circuit is compiled targeting this machine or gate set, the + registered workflow will be used. + + workflow (list[BasePass]): The workflow or list of passes that will + be executed if the MachineModel in a call to `compile` matches + `key`. If `key` is already registered, a warning will be logged. + + optimization_level (int): The optimization level with which to + register the workflow. + + target_type (str): Register a workflow for targets of this type. Must + be 'circuit', 'unitary', 'stateprep', or 'statemap'. + (Default: 'circuit') + + Example: + model_t = SpecificMachineModel(num_qudits, radixes) + workflow = [QuickPartitioner(3), NewFangledOptimization()] + register_workflow(model_t, workflow, level, 'circuit') + ... 
+ new_circuit = compile(circuit, model_t, optimization_level=level) + + Raises: + Warning: If a workflow for a given optimization_level is overwritten. + + ValueError: If `target_type` is not 'circuit', 'unitary', 'stateprep', + or 'statemap'. + """ + if target_type not in ['circuit', 'unitary', 'stateprep', 'statemap']: + m = 'target_type must be "circuit", "unitary", "stateprep", or ' + m += f'"statemap", got {target_type}.' + raise ValueError(m) + + if target_type == 'circuit': + global _compile_circuit_registry + _compile_registry = _compile_circuit_registry + elif target_type == 'unitary': + global _compile_unitary_registry + _compile_registry = _compile_unitary_registry + elif target_type == 'stateprep': + global _compile_stateprep_registry + _compile_registry = _compile_stateprep_registry + else: + global _compile_statemap_registry + _compile_registry = _compile_statemap_registry + + workflow = Workflow(workflow) + + new_workflow = {optimization_level: workflow} + if key in _compile_registry: + if optimization_level in _compile_registry[key]: + m = f'Overwritting workflow for {key} at level ' + m += f'{optimization_level} for target type {target_type}.' + m += 'If multiple Namespace packages are installed, ensure' + m += 'that their __init__.py files do not attempt to' + m += 'overwrite the same default Workflows.' 
+ warnings.warn(m) + _compile_registry[key].update(new_workflow) + else: + _compile_registry[key] = new_workflow diff --git a/bqskit/compiler/workflow.py b/bqskit/compiler/workflow.py index 6134d07aa..5b7f103df 100644 --- a/bqskit/compiler/workflow.py +++ b/bqskit/compiler/workflow.py @@ -10,6 +10,8 @@ from typing import TYPE_CHECKING from typing import Union +import dill + from bqskit.compiler.basepass import BasePass from bqskit.utils.random import seed_random_sources from bqskit.utils.typing import is_iterable @@ -39,6 +41,7 @@ def __init__(self, passes: WorkflowLike, name: str = '') -> None: """ if isinstance(passes, Workflow): self._passes: list[BasePass] = copy.deepcopy(passes._passes) + self._name: str = name if name else copy.deepcopy(passes._name) return if isinstance(passes, BasePass): @@ -87,6 +90,12 @@ def name(self) -> str: """The name of the pass.""" return self._name or self.__class__.__name__ + @staticmethod + def is_workflow(workflow: WorkflowLike) -> bool: + if not is_iterable(workflow): + return isinstance(workflow, BasePass) + return all(isinstance(p, BasePass) for p in workflow) + def __str__(self) -> str: name_seq = f'Workflow: {self.name}\n\t' pass_strs = [ @@ -118,5 +127,11 @@ def __getitem__(self, _key: slice, /) -> list[BasePass]: def __getitem__(self, _key: int | slice) -> BasePass | list[BasePass]: return self._passes.__getitem__(_key) + def __getstate__(self) -> bytes: + return dill.dumps(self.__dict__, recurse=True) + + def __setstate__(self, state: bytes) -> None: + self.__dict__.update(dill.loads(state)) + WorkflowLike = Union[Workflow, Iterable[BasePass], BasePass] diff --git a/bqskit/ext/__init__.py b/bqskit/ext/__init__.py index a6a4ee456..d2381085c 100644 --- a/bqskit/ext/__init__.py +++ b/bqskit/ext/__init__.py @@ -18,8 +18,11 @@ Aspen11Model AspenM2Model + ANKAA2Model + ANKAA9Q3Model H1_1Model H1_2Model + H2_1Model Sycamore23Model SycamoreModel model_from_backend @@ -64,8 +67,11 @@ from bqskit.ext.qiskit.translate import 
qiskit_to_bqskit from bqskit.ext.quantinuum import H1_1Model from bqskit.ext.quantinuum import H1_2Model +from bqskit.ext.quantinuum import H2_1Model from bqskit.ext.qutip.translate import bqskit_to_qutip from bqskit.ext.qutip.translate import qutip_to_bqskit +from bqskit.ext.rigetti import ANKAA2Model +from bqskit.ext.rigetti import ANKAA9Q3Model from bqskit.ext.rigetti import Aspen11Model from bqskit.ext.rigetti import AspenM2Model from bqskit.ext.supermarq import supermarq_critical_depth @@ -73,6 +79,7 @@ from bqskit.ext.supermarq import supermarq_liveness from bqskit.ext.supermarq import supermarq_parallelism from bqskit.ext.supermarq import supermarq_program_communication +# TODO: Deprecate imports from __init__, use lazy import to deprecate __all__ = [ @@ -94,6 +101,9 @@ 'AspenM2Model', 'H1_1Model', 'H1_2Model', + 'H2_1Model', + 'ANKAA2Model', + 'ANKAA9Q3Model', 'Sycamore23Model', 'SycamoreModel', ] diff --git a/bqskit/ext/honeywell.py b/bqskit/ext/honeywell.py deleted file mode 100644 index af0aa7537..000000000 --- a/bqskit/ext/honeywell.py +++ /dev/null @@ -1,23 +0,0 @@ -"""This module implemenets Honeywell QPU models.""" -from __future__ import annotations - -import warnings - -from bqskit.compiler.machine import MachineModel -from bqskit.ir.gate import Gate -from bqskit.ir.gates.constant.zz import ZZGate -from bqskit.ir.gates.parameterized.rz import RZGate -from bqskit.ir.gates.parameterized.u1q import U1qPi2Gate -from bqskit.ir.gates.parameterized.u1q import U1qPiGate - -warnings.warn( - 'Honeywell Quantum is now Quantinuum. Please use the ' - 'Quantinuum QPU models and gate sets instead. 
This warning will become' - 'an error in a future version of BQSKit.', - DeprecationWarning, -) - -honeywell_gate_set: set[Gate] = {U1qPiGate, U1qPi2Gate, RZGate(), ZZGate()} - -H1_1Model = MachineModel(20, None, honeywell_gate_set) -H1_2Model = MachineModel(12, None, honeywell_gate_set) diff --git a/bqskit/ext/qiskit/models.py b/bqskit/ext/qiskit/models.py index 237771390..9b42d57f3 100644 --- a/bqskit/ext/qiskit/models.py +++ b/bqskit/ext/qiskit/models.py @@ -23,7 +23,7 @@ def model_from_backend(backend: BackendV1) -> MachineModel: num_qudits = config.n_qubits gate_set = _basis_gate_str_to_bqskit_gate(config.basis_gates) coupling_map = list({tuple(sorted(e)) for e in config.coupling_map}) - return MachineModel(num_qudits, coupling_map, gate_set) # type: ignore + return MachineModel(num_qudits, coupling_map, gate_set) def _basis_gate_str_to_bqskit_gate(basis_gates: list[str]) -> set[Gate]: diff --git a/bqskit/ext/quantinuum.py b/bqskit/ext/quantinuum.py index a24a062de..948fe380f 100644 --- a/bqskit/ext/quantinuum.py +++ b/bqskit/ext/quantinuum.py @@ -3,12 +3,13 @@ from bqskit.compiler.machine import MachineModel from bqskit.ir.gate import Gate -from bqskit.ir.gates.constant.zz import ZZGate from bqskit.ir.gates.parameterized.rz import RZGate +from bqskit.ir.gates.parameterized.rzz import RZZGate from bqskit.ir.gates.parameterized.u1q import U1qPi2Gate from bqskit.ir.gates.parameterized.u1q import U1qPiGate -quantinuum_gate_set: set[Gate] = {U1qPiGate, U1qPi2Gate, RZGate(), ZZGate()} +quantinuum_gate_set: set[Gate] = {U1qPiGate, U1qPi2Gate, RZGate(), RZZGate()} H1_1Model = MachineModel(20, None, quantinuum_gate_set) H1_2Model = MachineModel(20, None, quantinuum_gate_set) +H2_1Model = MachineModel(56, None, quantinuum_gate_set) diff --git a/bqskit/ext/rigetti.py b/bqskit/ext/rigetti.py index 15bba867a..2b2c322cf 100644 --- a/bqskit/ext/rigetti.py +++ b/bqskit/ext/rigetti.py @@ -3,6 +3,7 @@ from bqskit.compiler.machine import MachineModel from 
bqskit.ir.gates.constant.cz import CZGate +from bqskit.ir.gates.constant.iswap import ISwapGate from bqskit.ir.gates.constant.sx import SXGate from bqskit.ir.gates.constant.x import XGate from bqskit.ir.gates.parameterized.rz import RZGate @@ -10,6 +11,8 @@ rigetti_gate_set = {SXGate(), XGate(), RZGate(), CZGate()} +ankaa_gate_set = {SXGate(), XGate(), RZGate(), CZGate(), ISwapGate()} + _aspen_11_coupling_graph = CouplingGraph([ # Ring 1 (0, 1), (1, 2), (2, 3), (3, 4), @@ -79,10 +82,20 @@ _aspen_m2_coupling_graph = CouplingGraph(_links) """Retrieved August 31, 2022: https://qcs.rigetti.com/qpus.""" +_ankaa_9q_3_coupling_graph = CouplingGraph.grid(3, 3) +"""Retrieved September 11, 2024: https://qcs.rigetti.com/qpus.""" + +_ankaa_2_coupling_graph = CouplingGraph.grid(7, 12) +"""Retrieved September 11, 2024: https://qcs.rigetti.com/qpus.""" + Aspen11Model = MachineModel(40, _aspen_11_coupling_graph, rigetti_gate_set) """A BQSKit MachineModel for Rigetti's Aspen-11 quantum processor.""" AspenM2Model = MachineModel(80, _aspen_m2_coupling_graph, rigetti_gate_set) """A BQSKit MachineModel for Rigetti's Aspen-M-2 quantum processor.""" +ANKAA2Model = MachineModel(84, _ankaa_2_coupling_graph, ankaa_gate_set) + +ANKAA9Q3Model = MachineModel(9, _ankaa_9q_3_coupling_graph, ankaa_gate_set) + __all__ = ['Aspen11Model', 'AspenM2Model'] diff --git a/bqskit/ir/__init__.py b/bqskit/ir/__init__.py index 10d0e4342..9959e3f04 100644 --- a/bqskit/ir/__init__.py +++ b/bqskit/ir/__init__.py @@ -62,6 +62,8 @@ from bqskit.ir.interval import CycleInterval from bqskit.ir.interval import IntervalLike from bqskit.ir.iterator import CircuitIterator +from bqskit.ir.lang import register_language as _register_language +from bqskit.ir.lang.qasm2 import OPENQASM2Language as _qasm from bqskit.ir.location import CircuitLocation from bqskit.ir.location import CircuitLocationLike from bqskit.ir.operation import Operation @@ -71,6 +73,11 @@ from bqskit.ir.region import CircuitRegionLike from 
bqskit.ir.structure import CircuitStructure + +# Register supported languages +_register_language('qasm', _qasm()) + + __all__ = [ 'Operation', 'Circuit', diff --git a/bqskit/ir/circuit.py b/bqskit/ir/circuit.py index d58b299a7..3c7fed2b7 100644 --- a/bqskit/ir/circuit.py +++ b/bqskit/ir/circuit.py @@ -3,21 +3,21 @@ import copy import logging +import pickle import warnings from typing import Any +from typing import Callable from typing import cast from typing import Collection from typing import Dict from typing import Iterable from typing import Iterator -from typing import List from typing import Optional from typing import overload from typing import Sequence -from typing import Set -from typing import Tuple from typing import TYPE_CHECKING +import dill import numpy as np import numpy.typing as npt @@ -25,6 +25,7 @@ from bqskit.ir.gates.circuitgate import CircuitGate from bqskit.ir.gates.constant.unitary import ConstantUnitaryGate from bqskit.ir.gates.measure import MeasurementPlaceholder +from bqskit.ir.interval import CycleInterval from bqskit.ir.iterator import CircuitIterator from bqskit.ir.lang import get_language from bqskit.ir.location import CircuitLocation @@ -902,13 +903,53 @@ def first_on(self, qudit: int) -> CircuitPoint | None: """Report the point for the first operation on `qudit` if it exists.""" return self._front[qudit] - def next(self, point: CircuitPoint) -> set[CircuitPoint]: + def next( + self, + point: CircuitPointLike | CircuitRegionLike, + ) -> set[CircuitPoint]: """Return the points of operations dependent on the one at `point`.""" - return {p for p in self._dag[point][1].values() if p is not None} + if CircuitRegion.is_region(point): + points = [] + for cyc_op in self.operations_with_cycles(qudits_or_region=point): + points.append((cyc_op[0], cyc_op[1].location[0])) + + next_points = set() + for p in points: + for next in self.next(p): + if next not in points: + next_points.add(next) + + return next_points + + return { + p + for p in 
self._dag[point][1].values() # type: ignore + if p is not None + } - def prev(self, point: CircuitPoint) -> set[CircuitPoint]: + def prev( + self, + point: CircuitPointLike | CircuitRegionLike, + ) -> set[CircuitPoint]: """Return the points of operations the one at `point` depends on.""" - return {p for p in self._dag[point][0].values() if p is not None} + if CircuitRegion.is_region(point): + points = [] + for cyc_op in self.operations_with_cycles(qudits_or_region=point): + points.append((cyc_op[0], cyc_op[1].location[0])) + + prev_points = set() + for p in points: + for prev in self.prev(p): + if prev not in points: + prev_points.add(prev) + + return prev_points + + return { + p + for p in self._dag[point][0].values() # type: ignore + if p is not None + } # endregion @@ -1035,33 +1076,8 @@ def point( raise ValueError('No such operation exists in the circuit.') - def append(self, op: Operation) -> int: - """ - Append `op` to the end of the circuit and return its cycle index. - - Args: - op (Operation): The operation to append. - - Returns: - int: The cycle index of the appended operation. - - Raises: - ValueError: If `op` cannot be placed on the circuit due to - either an invalid location or gate radix mismatch. - - Notes: - Due to the circuit being represented as a matrix, - `circuit.append(op)` does not imply `op` is last in simulation - order but it implies `op` is in the last cycle of circuit. - - Examples: - >>> from bqskit.ir.gates import HGate - >>> circ = Circuit(1) - >>> op = Operation(HGate(), [0]) - >>> circ.append(op) # Appends a Hadamard gate to qudit 0. 
- """ - self.check_valid_operation(op) - cycle_index = self._find_available_or_append_cycle(op.location) + def _append(self, op: Operation, cycle_index: int) -> None: + """Append the operation to the circuit at the specified cycle.""" point = CircuitPoint(cycle_index, op.location[0]) prevs: dict[int, CircuitPoint | None] = {i: None for i in op.location} @@ -1096,6 +1112,34 @@ def append(self, op: Operation) -> int: self._gate_info[op.gate] = 0 self._gate_info[op.gate] += 1 + def append(self, op: Operation) -> int: + """ + Append `op` to the end of the circuit and return its cycle index. + + Args: + op (Operation): The operation to append. + + Returns: + int: The cycle index of the appended operation. + + Raises: + ValueError: If `op` cannot be placed on the circuit due to + either an invalid location or gate radix mismatch. + + Notes: + Due to the circuit being represented as a matrix, + `circuit.append(op)` does not imply `op` is last in simulation + order but it implies `op` is in the last cycle of circuit. + + Examples: + >>> from bqskit.ir.gates import HGate + >>> circ = Circuit(1) + >>> op = Operation(HGate(), [0]) + >>> circ.append(op) # Appends a Hadamard gate to qudit 0. + """ + self.check_valid_operation(op) + cycle_index = self._find_available_or_append_cycle(op.location) + self._append(op, cycle_index) return cycle_index def append_gate( @@ -1853,6 +1897,10 @@ def check_region( """ Check `region` to be a valid in the context of this circuit. + A CircuitRegion is valid if it is within the bounds of the circuit + and for every pair of operations in the region, there is no path + between them that exits the region. + Args: region (CircuitRegionLike): The region to check. 
@@ -1883,35 +1931,49 @@ def check_region( f"but region's maximum cycle is {region.max_cycle}.", ) - for qudit_index, cycle_intervals in region.items(): - for other_qudit_index, other_cycle_intervals in region.items(): - if cycle_intervals.overlaps(other_cycle_intervals): + if strict: + for qudit_index, cycle_intervals in region.items(): + for other_qudit_index, other_cycle_intervals in region.items(): + if not cycle_intervals.overlaps(other_cycle_intervals): + raise ValueError('Disconnect detected in region.') + + cycles_ops = self.operations_with_cycles( + qudits_or_region=region, exclude=True, + ) + points = [(cop[0], cop[1].location[0]) for cop in cycles_ops] + known_to_never_reenter = set() + + # Walk back from max cycle + for pt in sorted(points, key=lambda x: x[0], reverse=True): + + # Max cycle is valid in base case + if pt[0] == region.max_cycle: + continue + + frontier = self.next(pt) + while frontier: + pt2 = frontier.pop() + + # Walk only paths that exit the region + if pt2 in points: continue - involved_qudits = {qudit_index} - min_index = min( - cycle_intervals.upper, - other_cycle_intervals.upper, - ) - max_index = max( - cycle_intervals.lower, - other_cycle_intervals.lower, - ) - for cycle_index in range(min_index + 1, max_index): - try: - ops = self[cycle_index, involved_qudits] - except IndexError: - continue - if strict: - raise ValueError('Disconnect detected in region.') + # Stop walking after the max cycle + if pt2[0] >= region.max_cycle: + continue - if any(other_qudit_index in op.location for op in ops): - raise ValueError( - 'Disconnected region has excluded gate in middle.', - ) + # Skip this point if previously determined it to be good + if pt2 in known_to_never_reenter: + continue + + expansion = self.next(pt2) - for op in ops: - involved_qudits.update(op.location) + # If there is a path that re-enters the region, fail + if any(p in points for p in expansion): + raise ValueError('Disconnect detected in region.') + + 
frontier.update(expansion) + known_to_never_reenter.add(pt2) def straighten( self, @@ -2092,27 +2154,46 @@ def unfold_all(self) -> None: def surround( self, - point: CircuitPointLike, + point: CircuitPointLike | CircuitRegionLike, num_qudits: int, bounding_region: CircuitRegionLike | None = None, - fail_quickly: bool = False, + fail_quickly: bool | None = None, + filter: Callable[[CircuitRegion], bool] | None = None, + scoring_fn: Callable[[CircuitRegion], float] | None = None, ) -> CircuitRegion: """ - Retrieve the maximal region in this circuit with `point` included. + Retrieve the maximal connected region in this circuit with `point`. Args: - point (CircuitPointLike): Find a surrounding region for this - point. This point will be in the final CircuitRegion. + point (CircuitPointLike | CircuitRegionLike): Find a surrounding + region for this point (or region). This point (or region) + will be in the final CircuitRegion. - num_qudits (int): The number of qudits to include in the region. + num_qudits (int): The maximum number of qudits to include in + the surrounding region. bounding_region (CircuitRegionLike | None): An optional region that bounds the resulting region. - fail_quickly (bool): If set to true, will not branch on + fail_quickly (bool | None): If set to true, will not branch on an invalid region. This will lead to a much faster result in some cases at the cost of only approximating - the maximal region. + the maximal region. (Deprecated, does nothing now besides + print a warning if a bool.) + + filter (Callable[[CircuitRegion], bool] | None): The filter + function determines if a candidate region is valid in the + caller's context. This is used to prune the search space + of the surround function. If None, then no filtering is + done. It takes a CircuitRegion and returns a boolean. + Only regions that pass the filter are considered. 
+ + scoring_fn (Callable[[CircuitRegion], float] | None): The + scoring function determines the "best" surrounding region. + If left as None, then this will default to the region with + the most number of gates with larger gates worth more. + It takes a CircuitRegion and returns a float. Larger scores + are better. Raises: IndexError: If `point` is not a valid index. @@ -2124,6 +2205,8 @@ def surround( ValueError: If `bounding_region` is invalid. + ValueError: If the initial node does not pass the filter. + Notes: This algorithm explores outward horizontally as much as possible. When a gate is encountered that involves another qudit not @@ -2143,193 +2226,151 @@ def surround( f'Expected a positive integer num_qudits, got {num_qudits}.', ) - if bounding_region is not None: - bounding_region = CircuitRegion(bounding_region) + if filter is not None and not callable(filter): + raise TypeError(f'Expected callable filter, got {type(filter)}.') - point = self.normalize_point(point) + def default_scoring_fn(region: CircuitRegion) -> float: + return float(sum(op.num_qudits * 100 for op in self[region])) - init_op: Operation = self[point] # Allow starting at an idle point + if scoring_fn is None: + scoring_fn = default_scoring_fn - if init_op.num_qudits > num_qudits: - raise ValueError('Gate at point is too large for num_qudits.') + if not callable(scoring_fn): + raise TypeError( + f'Expected callable scoring_fn, got {type(scoring_fn)}.', + ) - HalfWire = Tuple[CircuitPoint, str] - """ - A HalfWire is a point in the circuit and a direction. + if fail_quickly is not None: + warnings.warn( + 'The fail_quickly argument is deprecated and does nothing. ' + 'Surround will always attempt to find the maximal region. ' + 'This argument will be removed in a future release and this ' + 'warning will become an error.', + DeprecationWarning, + ) - This represents a point to start exploring from and a direction to - explore in. 
- """ + if bounding_region is not None: + bounding_region = CircuitRegion(bounding_region) - Node = Tuple[ - List[HalfWire], - Set[Tuple[int, Operation]], - CircuitLocation, - Set[CircuitPoint], - ] - """ - A Node in the search tree. + if CircuitPoint.is_point(point): + if self.is_point_idle(point): + init_region = CircuitRegion({point[1]: (point[0], point[0])}) + else: + init_region = self.get_region([point]) + elif CircuitRegion.is_region(point): + init_region = CircuitRegion(point) + else: + raise TypeError( + f'Expected CircuitPoint or CircuitRegion, got {type(point)}.', + ) - Each node represents a region that may grow further. The data structure - tracks all HalfWires in the region and the set of operations inside the - region. During node exploration each HalfWire is walked until we find a - multi-qudit gate. Multi- qudit gates form branches in the tree on - whether on the gate should be included. The node structure additionally - stores the set of qudit indices involved in the region currently. Also, - we track points that have already been explored to reduce repetition. 
- """ + if init_region.num_qudits > num_qudits: + raise ValueError('Initial region is too large for num_qudits.') - # Initialize the frontier - init_node = ( - [ - (CircuitPoint(point[0], qudit_index), 'left') - for qudit_index in init_op.location - ] - + [ - (CircuitPoint(point[0], qudit_index), 'right') - for qudit_index in init_op.location - ], - {(point[0], init_op)}, - init_op.location, - {CircuitPoint(point[0], q) for q in init_op.location}, - ) + if filter is not None and not filter(init_region): + raise ValueError('Initial region does not pass filter.') - frontier: list[Node] = [init_node] + # Initialize Search + frontier: list[CircuitRegion] = [init_region] + seen: set[CircuitRegion] = set() # Track best so far - def score(node: Node) -> int: - return sum(op[1].num_qudits for op in node[1]) - - best_score = score(init_node) - best_region = self.get_region({(point[0], init_op.location[0])}) + best_score = (scoring_fn(init_region), init_region.num_qudits) + best_region = init_region # Exhaustive Search while len(frontier) > 0: node = frontier.pop(0) - _logger.debug('popped node:') - _logger.debug(node[0]) - _logger.debug(f'Items remaining in the frontier: {len(frontier)}') # Evaluate node - if score(node) > best_score: - # Calculate region from best node and return - points = {(cycle, op.location[0]) for cycle, op in node[1]} - - try: - best_region = self.get_region(points) - best_score = score(node) - _logger.debug(f'new best: {best_region}.') - - # Need to reject bad regions - except ValueError: - if fail_quickly: - continue + node_score = (scoring_fn(node), node.num_qudits) + if node_score > best_score: + best_region = node + best_score = node_score # Expand node - absorbed_gates: set[tuple[int, Operation]] = set() - branches: set[tuple[int, int, Operation]] = set() - before_branch_half_wires: dict[int, HalfWire] = {} - for i, half_wire in enumerate(node[0]): - - cycle_index, qudit_index = half_wire[0] - step = -1 if half_wire[1] == 'left' else 1 - - 
while True: - - # Take a step - cycle_index += step - - # Stop at edges - if cycle_index < 0 or cycle_index >= self.num_cycles: - break - - # Stop when outside bounds - if bounding_region is not None: - if (cycle_index, qudit_index) not in bounding_region: + for point in self.next(node).union(self.prev(node)): + # Create new region by adding the gate at this point + region_bldr = {k: v for k, v in node.items()} + op = self[point] + valid_region = True + need_to_fully_check = False + for qudit in op.location: + if qudit not in region_bldr: + region_bldr[qudit] = CycleInterval(point[0], point[0]) + need_to_fully_check = True + + elif point[0] < region_bldr[qudit][0]: + # Check for gates in the middle not in region + if any( + not self.is_point_idle((i, qudit)) + for i in range(point[0] + 1, region_bldr[qudit][0]) + ): + valid_region = False break - # Stop when exploring previously explored points - point = CircuitPoint(cycle_index, qudit_index) - if point in node[3]: - break - node[3].add(point) + # Absorb Single-qudit gates + index = point[0] + while index > 0: + if not self.is_point_idle((index - 1, qudit)): + prev_op = self[index - 1, qudit] + if len(prev_op.location) != 1: + break + index -= 1 + + region_bldr[qudit] = CycleInterval( + index, + region_bldr[qudit][1], + ) - # Continue until next operation - if self.is_point_idle(point): - continue - op: Operation = self[cycle_index, qudit_index] + elif point[0] > region_bldr[qudit][1]: + # Check for gates in the middle not in region + if any( + not self.is_point_idle((i, qudit)) + for i in range(region_bldr[qudit][1] + 1, point[0]) + ): + valid_region = False + break - # Gates already in region stop the half_wire - if (cycle_index, op) in node[1]: - break + # Absorb Single-qudit gates + index = point[0] + while index < self.num_cycles - 1: + if not self.is_point_idle((index + 1, qudit)): + next_op = self[index + 1, qudit] + if len(next_op.location) != 1: + break + index += 1 + + region_bldr[qudit] = CycleInterval( 
+ region_bldr[qudit][0], + index, + ) - # Gates already accounted for stop the half_wire - if (cycle_index, op) in absorbed_gates: - break + # Discard too large regions + if len(region_bldr) > num_qudits: + continue - if (cycle_index, op) in [(c, o) for h, c, o in branches]: - break + # Discard invalid regions + if not valid_region: + continue - # Absorb single-qudit gates - if len(op.location) == 1: - absorbed_gates.add((cycle_index, op)) + if need_to_fully_check: + if not self.is_valid_region(region_bldr): continue - # Operations that are too large stop the half_wire - if len(op.location.union(node[2])) > num_qudits: - break + new_region = CircuitRegion(region_bldr) - # Otherwise branch on the operation - branches.add((i, cycle_index, op)) + # Check uniqueness + if new_region in seen: + continue - # Track state of half wire right before branch - prev_point = CircuitPoint(cycle_index - step, qudit_index) - before_branch_half_wires[i] = (prev_point, half_wire[1]) - break + # Check filter + if filter is not None and not filter(new_region): + continue - # Compute children and extend frontier - for half_wire_index, cycle_index, op in branches: - - child_half_wires = [ - half_wire - for i, half_wire in before_branch_half_wires.items() - if half_wire_index != i - ] - - qudit = node[0][half_wire_index][0].qudit - direction = node[0][half_wire_index][1] - left_expansion = [ - (CircuitPoint(cycle_index, qudit_index), 'left') - for qudit_index in op.location - if qudit != qudit_index or direction == 'left' - ] - right_expansion = [ - (CircuitPoint(cycle_index, qudit_index), 'right') - for qudit_index in op.location - if qudit != qudit_index or direction == 'right' - ] - expansion = left_expansion + right_expansion - - # Branch/Gate not taken - frontier.append(( - child_half_wires, - node[1] | absorbed_gates, - node[2], - node[3], - )) - - # Branch/Gate taken - op_points = {CircuitPoint(cycle_index, q) for q in op.location} - frontier.append(( - list(set(child_half_wires + 
expansion)), - node[1] | absorbed_gates | {(cycle_index, op)}, - node[2].union(op.location), - node[3] | op_points, - )) - - # Append terminal node to handle absorbed gates with no branches - if len(node[1] | absorbed_gates) != len(node[1]): - frontier.append(([], node[1] | absorbed_gates, *node[2:])) + # Expand frontier + frontier.append(new_region) + seen.add(new_region) return best_region @@ -2717,17 +2758,9 @@ def perform( :class:`~bqskit.compiler.compiler.Compiler` directly. """ from bqskit.compiler.compiler import Compiler - from bqskit.compiler.passdata import PassData - from bqskit.compiler.task import CompilationTask - - pass_data = PassData(self) - if data is not None: - pass_data.update(data) with Compiler() as compiler: - task = CompilationTask(self, [compiler_pass]) - task.data = pass_data - task_id = compiler.submit(task) + task_id = compiler.submit(self, [compiler_pass], data=data) self.become(compiler.result(task_id)) # type: ignore def instantiate( @@ -3241,4 +3274,73 @@ def from_operation(op: Operation) -> Circuit: circuit.append_gate(op.gate, list(range(circuit.num_qudits)), op.params) return circuit + def __reduce__(self) -> tuple[ + Callable[ + [int, tuple[int, ...], list[tuple[bool, bytes]], bytes], + Circuit, + ], + tuple[int, tuple[int, ...], list[tuple[bool, bytes]], bytes], + ]: + """Return the pickle state of the circuit.""" + serialized_gates: list[tuple[bool, bytes]] = [] + gate_table = {} + for gate in self.gate_set: + gate_table[gate] = len(serialized_gates) + if gate.__class__.__module__.startswith('bqskit'): + serialized_gates.append((False, pickle.dumps(gate))) + else: + serialized_gates.append((True, dill.dumps(gate, recurse=True))) + + cycles: list[list[tuple[int, tuple[int, ...], list[float]]]] = [] + last_cycle = -1 + for cycle, op in self.operations_with_cycles(): + + if cycle != last_cycle: + last_cycle = cycle + cycles.append([]) + + marshalled_op = ( + gate_table[op.gate], + op.location._location, + op.params, + ) + 
cycles[-1].append(marshalled_op) + + data = ( + self.num_qudits, + self.radixes, + serialized_gates, + pickle.dumps(cycles), + ) + return (rebuild_circuit, data) + # endregion + + +def rebuild_circuit( + num_qudits: int, + radixes: tuple[int, ...], + serialized_gates: list[tuple[bool, bytes]], + serialized_cycles: bytes, +) -> Circuit: + """Rebuild a circuit from a pickle state.""" + circuit = Circuit(num_qudits, radixes) + + gate_table = {} + for i, (is_dill, serialized_gate) in enumerate(serialized_gates): + if is_dill: + gate = dill.loads(serialized_gate) + else: + gate = pickle.loads(serialized_gate) + gate_table[i] = gate + + cycles = pickle.loads(serialized_cycles) + for i, cycle in enumerate(cycles): + circuit._append_cycle() + for marshalled_op in cycle: + gate = gate_table[marshalled_op[0]] + location = marshalled_op[1] + params = marshalled_op[2] + circuit._append(Operation(gate, location, params), i) + + return circuit diff --git a/bqskit/ir/gates/__init__.py b/bqskit/ir/gates/__init__.py index 1c9a83cbb..183a6426c 100644 --- a/bqskit/ir/gates/__init__.py +++ b/bqskit/ir/gates/__init__.py @@ -73,6 +73,7 @@ CUGate FSIMGate PauliGate + PauliZGate PhasedXZGate RSU3Gate RXGate @@ -99,6 +100,7 @@ :template: autosummary/gate.rst ControlledGate + PowerGate DaggerGate EmbeddedGate FrozenParameterGate diff --git a/bqskit/ir/gates/circuitgate.py b/bqskit/ir/gates/circuitgate.py index 8c33c4b52..d870b7931 100644 --- a/bqskit/ir/gates/circuitgate.py +++ b/bqskit/ir/gates/circuitgate.py @@ -137,6 +137,7 @@ def get_qasm_gate_def(self) -> str: ', '.join([str(p) for p in params]), ', q'.join([str(q) for q in op.location]), ).replace('()', '') + param_index += op.num_params ret += '}\n' return ret diff --git a/bqskit/ir/gates/composed/__init__.py b/bqskit/ir/gates/composed/__init__.py index d1f065f34..3eb5ac849 100644 --- a/bqskit/ir/gates/composed/__init__.py +++ b/bqskit/ir/gates/composed/__init__.py @@ -5,11 +5,13 @@ from bqskit.ir.gates.composed.daggergate import 
DaggerGate from bqskit.ir.gates.composed.embedded import EmbeddedGate from bqskit.ir.gates.composed.frozenparam import FrozenParameterGate +from bqskit.ir.gates.composed.powergate import PowerGate from bqskit.ir.gates.composed.tagged import TaggedGate from bqskit.ir.gates.composed.vlg import VariableLocationGate __all__ = [ 'ControlledGate', + 'PowerGate', 'DaggerGate', 'EmbeddedGate', 'FrozenParameterGate', diff --git a/bqskit/ir/gates/composed/powergate.py b/bqskit/ir/gates/composed/powergate.py new file mode 100644 index 000000000..30d489099 --- /dev/null +++ b/bqskit/ir/gates/composed/powergate.py @@ -0,0 +1,155 @@ +"""This module implements the DaggerGate Class.""" +from __future__ import annotations + +import re + +import numpy as np +import numpy.typing as npt + +from bqskit.ir.gate import Gate +from bqskit.ir.gates.composed.daggergate import DaggerGate +from bqskit.ir.gates.composedgate import ComposedGate +from bqskit.qis.unitary.differentiable import DifferentiableUnitary +from bqskit.qis.unitary.unitary import RealVector +from bqskit.qis.unitary.unitarymatrix import UnitaryMatrix +from bqskit.utils.docs import building_docs +from bqskit.utils.typing import is_integer + + +class PowerGate( + ComposedGate, + DifferentiableUnitary, +): + """ + An arbitrary inverted gate. + + The PowerGate is a composed gate that equivalent to the + integer power of the input gate. + + Examples: + >>> from bqskit.ir.gates import TGate, TdgGate + >>> PowerGate(TGate(),2).get_unitary() == + TdgGate().get_unitary()*TdgGate().get_unitary() + True + """ + + def __init__(self, gate: Gate, power: int = 1) -> None: + """ + Create a gate which is the integer power of the input gate. + + Args: + gate (Gate): The Gate to conjugate transpose. + power (int): The power index for the PowerGate. 
+ """ + if not isinstance(gate, Gate): + raise TypeError('Expected gate object, got %s' % type(gate)) + + if not is_integer(power): + raise TypeError(f'Expected integer power, got {type(power)}.') + + self.gate = gate + self.power = power + self._name = f'[{gate.name}^{power}]' + self._num_params = gate.num_params + self._num_qudits = gate.num_qudits + self._radixes = gate.radixes + + # If input is a constant gate, we can cache the unitary. + if self.num_params == 0 and not building_docs(): + self.utry = self.gate.get_unitary([]).ipower(power) + + def get_unitary(self, params: RealVector = []) -> UnitaryMatrix: + """Return the unitary for this gate, see :class:`Unitary` for more.""" + if hasattr(self, 'utry'): + return self.utry + + return self.gate.get_unitary(params).ipower(self.power) + + def get_grad(self, params: RealVector = []) -> npt.NDArray[np.complex128]: + """ + Return the gradient for this gate. + + See :class:`DifferentiableUnitary` for more info. + + Notes: + The derivative of the integer power of matrix is equal + to the derivative of the matrix multiplied by + the integer-1 power of the matrix + and by the integer power. + """ + if hasattr(self, 'utry'): + return np.array([]) + + _, grad = self.get_unitary_and_grad(params) + return grad + + def get_unitary_and_grad( + self, + params: RealVector = [], + ) -> tuple[UnitaryMatrix, npt.NDArray[np.complex128]]: + """ + Return the unitary and gradient for this gate. + + See :class:`DifferentiableUnitary` for more info. 
+ """ + # Constant gate case + if hasattr(self, 'utry'): + return self.utry, np.array([]) + + grad_shape = (self.num_params, self.dim, self.dim) + + # Identity gate case + if self.power == 0: + utry = UnitaryMatrix.identity(self.dim) + grad = np.zeros(grad_shape, dtype=np.complex128) + return utry, grad + + # Invert the gate if the power is negative + gate = self.gate if self.power > 0 else DaggerGate(self.gate) + power = abs(self.power) + + # Parallel Dicts for unitary and gradient powers + utrys = {} # utrys[i] = gate^(2^i) + grads = {} # grads[i] = d(gate^(2^i))/d(params) + + # decompose the power as sum of powers of 2 + power_bin = bin(abs(power))[2:] + binary_decomp = [ + len(power_bin) - 1 - xb.start() + for xb in re.finditer('1', power_bin) + ][::-1] + max_power_of_2 = max(binary_decomp) + + # Base Case: 2^0 + utrys[0], grads[0] = gate.get_unitary_and_grad(params) # type: ignore + + # Loop over powers of 2 + for i in range(1, max_power_of_2 + 1): + # u^(2^i) = u^(2^(i-1)) @ u^(2^(i-1)) + utrys[i] = utrys[i - 1] @ utrys[i - 1] + + # d[u^(2^i)] = d[u^(2^(i-1)) @ u^(2^(i-1))] = + grads[i] = grads[i - 1] @ utrys[i - 1] + utrys[i - 1] @ grads[i - 1] + + # Calculate binary composition of the unitary and gradient + utry = utrys[binary_decomp[0]] + grad = grads[binary_decomp[0]] + for i in sorted(binary_decomp[1:]): + grad = grad @ utrys[i] + utry @ grads[i] + utry = utry @ utrys[i] + + return utry, grad + + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, PowerGate) + and self.gate == other.gate + and self.power == other.power + ) + + def __hash__(self) -> int: + return hash((self.power, self.gate)) + + def get_inverse(self) -> Gate: + """Return the gate's inverse as a gate.""" + return PowerGate(self.gate, -self.power) diff --git a/bqskit/ir/gates/constant/h.py b/bqskit/ir/gates/constant/h.py index 633db3be2..34fafa34d 100644 --- a/bqskit/ir/gates/constant/h.py +++ b/bqskit/ir/gates/constant/h.py @@ -1,7 +1,13 @@ """This module implements 
the HGate.""" from __future__ import annotations -import numpy as np +from math import pi +from math import sqrt + +from numpy import array +from numpy import complex128 +from numpy import exp +from numpy import zeros from bqskit.ir.gates.constantgate import ConstantGate from bqskit.ir.gates.quditgate import QuditGate @@ -67,22 +73,22 @@ def __init__(self, radix: int = 2) -> None: # Calculate unitary if radix == 2: - matrix = np.array( + matrix = array( [ - [np.sqrt(2) / 2, np.sqrt(2) / 2], - [np.sqrt(2) / 2, -np.sqrt(2) / 2], + [sqrt(2) / 2, sqrt(2) / 2], + [sqrt(2) / 2, -sqrt(2) / 2], ], - dtype=np.complex128, + dtype=complex128, ) self._utry = UnitaryMatrix(matrix) else: - matrix = np.zeros([radix] * 2, dtype=np.complex128) - omega = np.exp(2j * np.pi / radix) + matrix = zeros([radix] * 2, dtype=complex128) + omega = exp(2j * pi / radix) for i in range(radix): for j in range(i, radix): val = omega ** (i * j) matrix[i, j] = val matrix[j, i] = val - matrix *= 1 / np.sqrt(radix) + matrix *= 1 / sqrt(radix) self._utry = UnitaryMatrix(matrix, self.radixes) diff --git a/bqskit/ir/gates/parameterized/__init__.py b/bqskit/ir/gates/parameterized/__init__.py index 3ee2c645c..546520b60 100644 --- a/bqskit/ir/gates/parameterized/__init__.py +++ b/bqskit/ir/gates/parameterized/__init__.py @@ -12,6 +12,7 @@ from bqskit.ir.gates.parameterized.cu import CUGate from bqskit.ir.gates.parameterized.fsim import FSIMGate from bqskit.ir.gates.parameterized.pauli import PauliGate +from bqskit.ir.gates.parameterized.pauliz import PauliZGate from bqskit.ir.gates.parameterized.phasedxz import PhasedXZGate from bqskit.ir.gates.parameterized.rsu3 import RSU3Gate from bqskit.ir.gates.parameterized.rx import RXGate @@ -41,6 +42,7 @@ 'CUGate', 'FSIMGate', 'PauliGate', + 'PauliZGate', 'PhasedXZGate', 'RSU3Gate', 'RXGate', diff --git a/bqskit/ir/gates/parameterized/pauliz.py b/bqskit/ir/gates/parameterized/pauliz.py new file mode 100644 index 000000000..4eb42d6d7 --- /dev/null +++ 
b/bqskit/ir/gates/parameterized/pauliz.py @@ -0,0 +1,103 @@ +"""This module implements the PauliZGate.""" +from __future__ import annotations + +from typing import Any + +import numpy as np +import numpy.typing as npt + +from bqskit.ir.gates.generalgate import GeneralGate +from bqskit.ir.gates.qubitgate import QubitGate +from bqskit.qis.pauliz import PauliZMatrices +from bqskit.qis.unitary.differentiable import DifferentiableUnitary +from bqskit.qis.unitary.unitary import RealVector +from bqskit.qis.unitary.unitarymatrix import UnitaryMatrix +from bqskit.utils.docs import building_docs +from bqskit.utils.math import dexpmv +from bqskit.utils.math import dot_product +from bqskit.utils.math import pauliz_expansion +from bqskit.utils.math import unitary_log_no_i + + +class PauliZGate(QubitGate, DifferentiableUnitary, GeneralGate): + """ + A gate representing an arbitrary diagonal rotation. + + This gate is given by: + + .. math:: + + \\exp({i(\\vec{\\alpha} \\cdot \\vec{\\sigma_Z^{\\otimes n}})}) + + Where :math:`\\vec{\\alpha}` are the gate's parameters, + :math:`\\vec{\\sigma}` are the PauliZ Z matrices, + and :math:`n` is the number of qubits this gate acts on. + """ + + def __init__(self, num_qudits: int) -> None: + """ + Create a PauliZGate acting on `num_qudits` qubits. + + Args: + num_qudits (int): The number of qudits this gate will act on. + + Raises: + ValueError: If `num_qudits` is nonpositive. 
+ """ + + if num_qudits <= 0: + raise ValueError(f'Expected positive integer, got {num_qudits}') + + self._name = f'PauliZGate({num_qudits})' + self._num_qudits = num_qudits + paulizs = PauliZMatrices(self.num_qudits) + self._num_params = len(paulizs) + if building_docs(): + self.sigmav: npt.NDArray[Any] = np.array([]) + else: + self.sigmav = (-1j / 2) * paulizs.numpy + + def get_unitary(self, params: RealVector = []) -> UnitaryMatrix: + """Return the unitary for this gate, see :class:`Unitary` for more.""" + self.check_parameters(params) + H = dot_product(params, self.sigmav) + eiH = np.diag(np.exp(np.diag(H))) + return UnitaryMatrix(eiH, check_arguments=False) + + def get_grad(self, params: RealVector = []) -> npt.NDArray[np.complex128]: + """ + Return the gradient for this gate. + + See :class:`DifferentiableUnitary` for more info. + + TODO: Accelerated gradient computation for diagonal matrices. + """ + self.check_parameters(params) + H = dot_product(params, self.sigmav) + _, dU = dexpmv(H, self.sigmav) + return dU + + def get_unitary_and_grad( + self, + params: RealVector = [], + ) -> tuple[UnitaryMatrix, npt.NDArray[np.complex128]]: + """ + Return the unitary and gradient for this gate. + + See :class:`DifferentiableUnitary` for more info. 
+ """ + self.check_parameters(params) + + H = dot_product(params, self.sigmav) + U, dU = dexpmv(H, self.sigmav) + return UnitaryMatrix(U, check_arguments=False), dU + + def calc_params(self, utry: UnitaryMatrix) -> list[float]: + """Return the parameters for this gate to implement `utry`""" + return list(-2 * pauliz_expansion(unitary_log_no_i(utry.numpy))) + + def __eq__(self, o: object) -> bool: + return isinstance(o, PauliZGate) and self.num_qudits == o.num_qudits + + def __hash__(self) -> int: + return hash((self.__class__.__name__, self.num_qudits)) diff --git a/bqskit/ir/interval.py b/bqskit/ir/interval.py index 5375955a6..97f8a6552 100644 --- a/bqskit/ir/interval.py +++ b/bqskit/ir/interval.py @@ -89,7 +89,7 @@ def __new__( 'Expected positive integers, got {lower} and {upper}.', ) - return super().__new__(cls, (lower, upper)) # type: ignore + return super().__new__(cls, (lower, upper)) @property def lower(self) -> int: diff --git a/bqskit/ir/iterator.py b/bqskit/ir/iterator.py index 87e9a3bd2..fa1813c69 100644 --- a/bqskit/ir/iterator.py +++ b/bqskit/ir/iterator.py @@ -248,18 +248,18 @@ def __init__( self.min_cycle = self.region.min_cycle self.max_cycle = self.region.max_cycle - if start < (self.min_cycle, self.min_qudit): - start = CircuitPoint(self.min_cycle, self.min_qudit) + if self.start < (self.min_cycle, self.min_qudit): + self.start = CircuitPoint(self.min_cycle, self.min_qudit) - if end > (self.max_cycle, self.max_qudit): - end = CircuitPoint(self.max_cycle, self.max_qudit) + if self.end > (self.max_cycle, self.max_qudit): + self.end = CircuitPoint(self.max_cycle, self.max_qudit) - assert isinstance(start, CircuitPoint) # TODO: Typeguard - assert isinstance(end, CircuitPoint) # TODO: Typeguard + assert isinstance(self.start, CircuitPoint) # TODO: Typeguard + assert isinstance(self.end, CircuitPoint) # TODO: Typeguard # Pointer into the circuit structure - self.cycle = start.cycle if not self.reverse else end.cycle - self.qudit = start.qudit if not 
self.reverse else end.qudit + self.cycle = self.start.cycle if not self.reverse else self.end.cycle + self.qudit = self.start.qudit if not self.reverse else self.end.qudit # Used to track changes to circuit structure self.num_ops = self.circuit.num_operations @@ -330,6 +330,8 @@ def __next__(self) -> Operation | tuple[int, Operation]: self.qudits_to_skip.add(self.qudit) continue + self.qudits_to_skip.update(op.location) + if self.exclude: if not all(qudit in self.qudits for qudit in op.location): continue @@ -340,8 +342,6 @@ def __next__(self) -> Operation | tuple[int, Operation]: ): continue - self.qudits_to_skip.update(op.location) - if self.and_cycles: return self.cycle, op diff --git a/bqskit/ir/point.py b/bqskit/ir/point.py index a44d1b1f5..0e510c449 100644 --- a/bqskit/ir/point.py +++ b/bqskit/ir/point.py @@ -66,7 +66,7 @@ def __new__( else: raise TypeError('Expected two integer arguments.') - return super().__new__(cls, (cycle, qudit)) # type: ignore + return super().__new__(cls, (cycle, qudit)) @property def cycle(self) -> int: diff --git a/bqskit/ir/region.py b/bqskit/ir/region.py index 78197e973..0d5cf9432 100644 --- a/bqskit/ir/region.py +++ b/bqskit/ir/region.py @@ -187,6 +187,11 @@ def empty(self) -> bool: """Return true if this region is empty.""" return len(self) == 0 + @property + def num_qudits(self) -> int: + """Return the number of qudits in this region.""" + return len(self) + def shift_left(self, amount_to_shift: int) -> CircuitRegion: """ Shift the region to the left by `amount_to_shift`. 
@@ -292,6 +297,10 @@ def overlaps(self, other: CircuitPointLike | CircuitRegionLike) -> bool: % type(other), ) + def copy(self) -> CircuitRegion: + """Return a deep copy of this region.""" + return CircuitRegion(self._intervals) + def __contains__(self, other: object) -> bool: if is_integer(other): return other in self._intervals.keys() diff --git a/bqskit/passes/__init__.py b/bqskit/passes/__init__.py index c7d85a504..b05f53246 100644 --- a/bqskit/passes/__init__.py +++ b/bqskit/passes/__init__.py @@ -33,6 +33,7 @@ QFASTDecompositionPass QPredictDecompositionPass SynthesisPass + WalshDiagonalSynthesisPass .. rubric:: Processing Passes @@ -43,6 +44,7 @@ ExhaustiveGateRemovalPass IterativeScanningGateRemovalPass ScanningGateRemovalPass + TreeScanningGateRemovalPass SubstitutePass .. rubric:: Retargeting Passes @@ -136,6 +138,10 @@ These passes either perform upper-bound error analysis of the PAM process. +.. autosummary:: + :toctree: autogen + :recursive: + TagPAMBlockDataPass CalculatePAMErrorsPass UnTagPAMBlockDataPass @@ -193,6 +199,7 @@ :toctree: autogen :recursive: + DiscreteLayerGenerator FourParamGenerator MiddleOutLayerGenerator SeedLayerGenerator @@ -258,6 +265,7 @@ from bqskit.passes.processing.iterative import IterativeScanningGateRemovalPass from bqskit.passes.processing.scan import ScanningGateRemovalPass from bqskit.passes.processing.substitute import SubstitutePass +from bqskit.passes.processing.treescan import TreeScanningGateRemovalPass from bqskit.passes.retarget.auto import AutoRebase2QuditGatePass from bqskit.passes.retarget.general import GeneralSQDecomposition from bqskit.passes.retarget.two import Rebase2QuditGatePass @@ -271,6 +279,7 @@ from bqskit.passes.rules.zxzxz import ZXZXZDecomposition from bqskit.passes.search.frontier import Frontier from bqskit.passes.search.generator import LayerGenerator +from bqskit.passes.search.generators.discrete import DiscreteLayerGenerator from bqskit.passes.search.generators.fourparam import 
FourParamGenerator from bqskit.passes.search.generators.middleout import MiddleOutLayerGenerator from bqskit.passes.search.generators.seed import SeedLayerGenerator @@ -282,6 +291,7 @@ from bqskit.passes.search.heuristics.astar import AStarHeuristic from bqskit.passes.search.heuristics.dijkstra import DijkstraHeuristic from bqskit.passes.search.heuristics.greedy import GreedyHeuristic +from bqskit.passes.synthesis.diagonal import WalshDiagonalSynthesisPass from bqskit.passes.synthesis.leap import LEAPSynthesisPass from bqskit.passes.synthesis.pas import PermutationAwareSynthesisPass from bqskit.passes.synthesis.qfast import QFASTDecompositionPass @@ -319,6 +329,7 @@ 'ScanPartitioner', 'QuickPartitioner', 'SynthesisPass', + 'WalshDiagonalSynthesisPass', 'LEAPSynthesisPass', 'QSearchSynthesisPass', 'QFASTDecompositionPass', @@ -330,6 +341,8 @@ 'UpdateDataPass', 'ToU3Pass', 'ScanningGateRemovalPass', + 'TreeScanningGateRemovalPass', + 'DiscreteLayerGenerator', 'SimpleLayerGenerator', 'AStarHeuristic', 'GreedyHeuristic', diff --git a/bqskit/passes/control/paralleldo.py b/bqskit/passes/control/paralleldo.py index 42b9bbeee..a95168058 100644 --- a/bqskit/passes/control/paralleldo.py +++ b/bqskit/passes/control/paralleldo.py @@ -34,7 +34,7 @@ def __init__( self, pass_sequences: Iterable[WorkflowLike], less_than: Callable[[Circuit, Circuit], bool], - pick_fisrt: bool = False, + pick_first: bool = False, ) -> None: """ Construct a ParallelDo. 
@@ -63,7 +63,7 @@ def __init__( self.workflows = [Workflow(p) for p in pass_sequences] self.less_than = less_than - self.pick_first = pick_fisrt + self.pick_first = pick_first if len(self.workflows) == 0: raise ValueError('Must specify at least one workflow.') diff --git a/bqskit/passes/control/predicates/diagonal.py b/bqskit/passes/control/predicates/diagonal.py new file mode 100644 index 000000000..0d72ca7f4 --- /dev/null +++ b/bqskit/passes/control/predicates/diagonal.py @@ -0,0 +1,39 @@ +"""This module implements the DiagonalPredicate class.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +from bqskit.passes.control.predicate import PassPredicate +from bqskit.utils.math import diagonal_distance + +if TYPE_CHECKING: + from bqskit.compiler.passdata import PassData + from bqskit.ir.circuit import Circuit + + +class DiagonalPredicate(PassPredicate): + """ + The DiagonalPredicate class. + + The DiagonalPredicate class returns True if the circuit's unitary can be + approximately inverted by a diagonal unitary. A unitary is approximately + inverted when the Hilbert-Schmidt distance to the identity is less than some + threshold. + """ + + def __init__(self, threshold: float) -> None: + """ + Construct a DiagonalPredicate. + + Args: + threshold (float): If a circuit can be approximately inverted + by a diagonal unitary (meaning the Hilbert-Schmidt distance + to the identity is less than or equal to this number after + multiplying by the diagonal unitary), True is returned. 
+ """ + self.threshold = threshold + + def get_truth_value(self, circuit: Circuit, data: PassData) -> bool: + """Call this predicate, see :class:`PassPredicate` for more info.""" + dist = diagonal_distance(circuit.get_unitary().numpy) + return dist <= self.threshold diff --git a/bqskit/passes/control/predicates/distributed.py b/bqskit/passes/control/predicates/distributed.py new file mode 100644 index 000000000..a1ceb2275 --- /dev/null +++ b/bqskit/passes/control/predicates/distributed.py @@ -0,0 +1,26 @@ +"""This module implements the DistributedPredicate class.""" +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from bqskit.passes.control.predicate import PassPredicate + +if TYPE_CHECKING: + from bqskit.compiler.passdata import PassData + from bqskit.ir.circuit import Circuit + +_logger = logging.getLogger(__name__) + + +class DistributedPredicate(PassPredicate): + """ + The DistributedPredicate class. + + The DistributedPredicate returns true if the targeted machine is distributed + across multiple chips. 
+ """ + + def get_truth_value(self, circuit: Circuit, data: PassData) -> bool: + """Call this predicate, see :class:`PassPredicate` for more info.""" + return data.model.coupling_graph.is_distributed() diff --git a/bqskit/passes/mapping/pam.py b/bqskit/passes/mapping/pam.py index 549ab2b6b..92127a2c3 100644 --- a/bqskit/passes/mapping/pam.py +++ b/bqskit/passes/mapping/pam.py @@ -277,7 +277,7 @@ def _get_best_perm( cg: CouplingGraph, F: set[CircuitPoint], pi: list[int], - D: list[list[int]], + D: list[list[float]], E: set[CircuitPoint], qudits: Sequence[int], ) -> tuple[tuple[int, ...], Circuit, tuple[int, ...]]: @@ -366,7 +366,7 @@ def _score_perm( circuit: Circuit, F: set[CircuitPoint], pi: list[int], - D: list[list[int]], + D: list[list[float]], perm: tuple[Sequence[int], Sequence[int]], E: set[CircuitPoint], ) -> float: diff --git a/bqskit/passes/mapping/sabre.py b/bqskit/passes/mapping/sabre.py index b257fda84..27d52c0da 100644 --- a/bqskit/passes/mapping/sabre.py +++ b/bqskit/passes/mapping/sabre.py @@ -363,7 +363,7 @@ def _get_best_swap( circuit: Circuit, F: set[CircuitPoint], E: set[CircuitPoint], - D: list[list[int]], + D: list[list[float]], cg: CouplingGraph, pi: list[int], decay: list[float], @@ -416,7 +416,7 @@ def _score_swap( circuit: Circuit, F: set[CircuitPoint], pi: list[int], - D: list[list[int]], + D: list[list[float]], swap: tuple[int, int], decay: list[float], E: set[CircuitPoint], @@ -475,7 +475,7 @@ def _get_distance( self, logical_qudits: Sequence[int], pi: list[int], - D: list[list[int]], + D: list[list[float]], ) -> float: """Calculate the expected number of swaps to connect logical qudits.""" min_term = np.inf @@ -493,7 +493,7 @@ def _uphill_swaps( logical_qudits: Sequence[int], cg: CouplingGraph, pi: list[int], - D: list[list[int]], + D: list[list[float]], ) -> Iterator[tuple[int, int]]: """Yield the swaps necessary to bring some of the qudits together.""" center_qudit = min( diff --git a/bqskit/passes/processing/__init__.py 
b/bqskit/passes/processing/__init__.py index e54966b93..675eb2268 100644 --- a/bqskit/passes/processing/__init__.py +++ b/bqskit/passes/processing/__init__.py @@ -5,10 +5,12 @@ from bqskit.passes.processing.iterative import IterativeScanningGateRemovalPass from bqskit.passes.processing.scan import ScanningGateRemovalPass from bqskit.passes.processing.substitute import SubstitutePass +from bqskit.passes.processing.treescan import TreeScanningGateRemovalPass __all__ = [ 'ExhaustiveGateRemovalPass', 'IterativeScanningGateRemovalPass', 'ScanningGateRemovalPass', 'SubstitutePass', + 'TreeScanningGateRemovalPass', ] diff --git a/bqskit/passes/processing/treescan.py b/bqskit/passes/processing/treescan.py new file mode 100644 index 000000000..12376e59d --- /dev/null +++ b/bqskit/passes/processing/treescan.py @@ -0,0 +1,237 @@ +"""This module implements the TreeScanningGateRemovalPass.""" +from __future__ import annotations + +import logging +from typing import Any +from typing import Callable + +from bqskit.compiler.basepass import BasePass +from bqskit.compiler.passdata import PassData +from bqskit.ir.circuit import Circuit +from bqskit.ir.operation import Operation +from bqskit.ir.opt.cost.functions import HilbertSchmidtResidualsGenerator +from bqskit.ir.opt.cost.generator import CostFunctionGenerator +from bqskit.runtime import get_runtime +from bqskit.utils.typing import is_integer +from bqskit.utils.typing import is_real_number + +_logger = logging.getLogger(__name__) + + +class TreeScanningGateRemovalPass(BasePass): + """ + The TreeScanningGateRemovalPass class. + + Starting from one side of the circuit, run the following: + + Split the circuit operations into chunks of size `tree_depth` + At every iteration: + a. Look at the next chunk of operations + b. Generate 2 ^ `tree_depth` circuits. Each circuit corresponds to every + combination of whether or not to include one of the operations in the chunk. + c. Instantiate in parallel all 2^`tree_depth` circuits + d. 
Choose the circuit that has the least number of operations and move + on to the next chunk of operations. + + This optimization is less greedy than the current + :class:`~bqskit.passes.processing.ScanningGateRemovalPass` removal, + which leads to much better quality circuits than ScanningGate. + In very rare occasions, ScanningGate may be able to outperform + TreeScan (since it is still greedy), but in general we can expect + TreeScan to almost always outperform ScanningGate. + """ + + def __init__( + self, + start_from_left: bool = True, + success_threshold: float = 1e-8, + cost: CostFunctionGenerator = HilbertSchmidtResidualsGenerator(), + instantiate_options: dict[str, Any] = {}, + tree_depth: int = 1, + collection_filter: Callable[[Operation], bool] | None = None, + ) -> None: + """ + Construct a TreeScanningGateRemovalPass. + + Args: + start_from_left (bool): Determines where the scan starts + attempting to remove gates from. If True, scan goes left + to right, otherwise right to left. (Default: True) + + success_threshold (float): The distance threshold that + determines successful termintation. Measured in cost + described by the hilbert schmidt cost function. + (Default: 1e-8) + + cost (CostFunction | None): The cost function that determines + successful removal of a gate. + (Default: HilbertSchmidtResidualsGenerator()) + + instantiate_options (dict[str: Any]): Options passed directly + to circuit.instantiate when instantiating circuit + templates. (Default: {}) + + tree_depth (int): The depth of the tree of potential + solutions to instantiate. Note that 2^(tree_depth) - 1 + circuits will be instantiated in parallel. Note that the default + behavior will be equivalent to normal ScanningGateRemoval + (Default: 1) + + collection_filter (Callable[[Operation], bool] | None): + A predicate that determines which operations should be + attempted to be removed. Called with each operation + in the circuit. 
If this returns true, this pass will + attempt to remove that operation. Defaults to all + operations. + """ + + if not is_real_number(success_threshold): + raise TypeError( + 'Expected real number for success_threshold' + ', got %s' % type(success_threshold), + ) + + if not isinstance(cost, CostFunctionGenerator): + raise TypeError( + 'Expected cost to be a CostFunctionGenerator, got %s' + % type(cost), + ) + + if not isinstance(instantiate_options, dict): + raise TypeError( + 'Expected dictionary for instantiate_options, got %s.' + % type(instantiate_options), + ) + + self.collection_filter = collection_filter or default_collection_filter + + if not callable(self.collection_filter): + raise TypeError( + 'Expected callable method that maps Operations to booleans for' + ' collection_filter, got %s.' % type(self.collection_filter), + ) + + if not is_integer(tree_depth): + raise TypeError( + 'Expected Integer type for tree_depth, got %s.' + % type(instantiate_options), + ) + + self.tree_depth = tree_depth + self.start_from_left = start_from_left + self.success_threshold = success_threshold + self.cost = cost + self.instantiate_options: dict[str, Any] = { + 'dist_tol': self.success_threshold, + 'min_iters': 10, + 'cost_fn_gen': self.cost, + } + self.instantiate_options.update(instantiate_options) + + @staticmethod + def get_tree_circs( + orig_num_cycles: int, + circuit_copy: Circuit, + cycle_and_ops: list[tuple[int, Operation]], + ) -> list[Circuit]: + """ + Generate all circuits to be instantiated in the tree scan. + + Args: + orig_num_cycles (int): The original number of cycles + in the circuit. This allows us to keep track of the shift + caused by previous deletions. + + circuit_copy (Circuit): Current state of the circuit. + + cycle_and_ops: list[(int, Operation)]: The next chunk + of operations to be considered for deletion. + + Returns: + list[Circuit]: A list of 2^(`tree_depth`) - 1 circuits + that remove up to `tree_depth` operations. 
The circuits + are sorted by the number of operations removed. + """ + all_circs = [circuit_copy.copy()] + for cycle, op in cycle_and_ops: + new_circs = [] + for circ in all_circs: + idx_shift = orig_num_cycles - circ.num_cycles + new_cycle = cycle - idx_shift + work_copy = circ.copy() + work_copy.pop((new_cycle, op.location[0])) + new_circs.append(work_copy) + new_circs.append(circ) + + all_circs = new_circs + + all_circs = sorted(all_circs, key=lambda x: x.num_operations) + # Remove circuit with no gates deleted + return all_circs[:-1] + + async def run(self, circuit: Circuit, data: PassData) -> None: + """Perform the pass's operation, see :class:`BasePass` for more.""" + instantiate_options = self.instantiate_options.copy() + if 'seed' not in instantiate_options: + instantiate_options['seed'] = data.seed + + start = 'left' if self.start_from_left else 'right' + _logger.debug(f'Starting tree scanning gate removal on the {start}.') + + target = self.get_target(circuit, data) + + circuit_copy = circuit.copy() + reverse_iter = not self.start_from_left + + ops_left = list(circuit.operations_with_cycles(reverse=reverse_iter)) + print( + f'Starting TreeScan with tree depth {self.tree_depth}' + f' on circuit with {len(ops_left)} gates', + ) + + while ops_left: + chunk = ops_left[:self.tree_depth] + ops_left = ops_left[self.tree_depth:] + + all_circs = TreeScanningGateRemovalPass.get_tree_circs( + circuit.num_cycles, circuit_copy, chunk, + ) + + _logger.debug( + 'Attempting removal of operation of up to' + f' {self.tree_depth} operations.', + ) + + instantiated_circuits: list[Circuit] = await get_runtime().map( + Circuit.instantiate, + all_circs, + target=target, + **instantiate_options, + ) + + dists = [self.cost(c, target) for c in instantiated_circuits] + + # Pick least count with least dist + for i, dist in enumerate(dists): + if dist < self.success_threshold: + # Log gates removed + gate_dict_orig = circuit_copy.gate_counts + gate_dict_new = 
instantiated_circuits[i].gate_counts + gates_removed = { + k: circuit_copy.gate_counts[k] - gate_dict_new.get(k, 0) + for k in gate_dict_orig.keys() + } + gates_removed = { + k: v for k, v in gates_removed.items() if v != 0 + } + _logger.debug( + f'Successfully removed {gates_removed} gates', + ) + circuit_copy = instantiated_circuits[i] + break + + circuit.become(circuit_copy) + + +def default_collection_filter(op: Operation) -> bool: + return True diff --git a/bqskit/passes/search/generators/__init__.py b/bqskit/passes/search/generators/__init__.py index aec7fab2e..dae5ac695 100644 --- a/bqskit/passes/search/generators/__init__.py +++ b/bqskit/passes/search/generators/__init__.py @@ -1,6 +1,7 @@ """This package contains LayerGenerator definitions.""" from __future__ import annotations +from bqskit.passes.search.generators.discrete import DiscreteLayerGenerator from bqskit.passes.search.generators.fourparam import FourParamGenerator from bqskit.passes.search.generators.middleout import MiddleOutLayerGenerator from bqskit.passes.search.generators.seed import SeedLayerGenerator @@ -10,6 +11,7 @@ from bqskit.passes.search.generators.wide import WideLayerGenerator __all__ = [ + 'DiscreteLayerGenerator', 'FourParamGenerator', 'MiddleOutLayerGenerator', 'SeedLayerGenerator', @@ -17,4 +19,5 @@ 'SingleQuditLayerGenerator', 'StairLayerGenerator', 'WideLayerGenerator', + 'DiscreteLayerGenerator', ] diff --git a/bqskit/passes/search/generators/discrete.py b/bqskit/passes/search/generators/discrete.py new file mode 100644 index 000000000..33ae4d0d9 --- /dev/null +++ b/bqskit/passes/search/generators/discrete.py @@ -0,0 +1,225 @@ +"""This module implements the DiscreteLayerGenerator class.""" +from __future__ import annotations + +import logging +from typing import Callable +from typing import Sequence + +from bqskit.compiler.passdata import PassData +from bqskit.ir.circuit import Circuit +from bqskit.ir.gate import Gate +from bqskit.ir.gates.constant.cx import CNOTGate +from 
bqskit.ir.gates.constant.h import HGate +from bqskit.ir.gates.constant.t import TGate +from bqskit.ir.gates.parameterized.pauliz import PauliZGate +from bqskit.ir.operation import Operation +from bqskit.passes.search.generator import LayerGenerator +from bqskit.qis.state.state import StateVector +from bqskit.qis.state.system import StateSystem +from bqskit.qis.unitary.unitarymatrix import UnitaryMatrix +from bqskit.utils.typing import is_sequence + + +_logger = logging.getLogger(__name__) + + +class DiscreteLayerGenerator(LayerGenerator): + """ + The DiscreteLayerGenerator class. + + Expands circuits using only discrete gates. This is a non-reinforcement + learning version of diagonalizing in + https://arxiv.org/abs/2409.00433. + """ + + def __init__( + self, + gateset: Sequence[Gate] = [HGate(), TGate(), CNOTGate()], + double_headed: bool = False, + dividing_gate_type: Callable[[int], Gate] = PauliZGate, + ) -> None: + """ + Construct a DiscreteLayerGenerator. + + Args: + gateset (Sequence[Gate]): A sequence of gates that can be used + in the output circuit. These must be non-parameterized gates. + (Default: [HGate, TGate, CNOTGate]) + + double_headed (bool): If True, successors will be generated by + both appending and prepending gates. This lets unitaries be + diagonalized instead of inverted. (Default: False) + + dividing_gate (Callable[[int], Gate]): A gate that goes between + the two heads of the discrete searches. If double_headed is + False, this gate simply goes at the beggining of the circuit. + (Default: PauliZGate) + + Raises: + ValueError: If the gateset is not a sequence. + + ValueError: If the gateset contains a parameterized gate. + + ValueError: If the radices of gates are different. + + TODO: + Check universality of gateset. + """ + if not is_sequence(gateset): + m = f'Expected sequence of gates, got {type(gateset)}.' 
+ raise ValueError(m) + + radix = gateset[0].radixes[0] + for gate in gateset: + if gate.num_params > 0: + m = 'Expected gate for constant gates, got parameterized' + m += f' {gate} gate.' + raise ValueError(m) + for rad in gate.radixes: + if rad != radix: + m = f'Radix mismatch on gate: {gate}. ' + m += f'Expected {radix}, got {rad}.' + raise ValueError(m) + self.gateset = gateset + self.double_headed = double_headed + self.dividing_gate_type = dividing_gate_type + + def gen_initial_layer( + self, + target: UnitaryMatrix | StateVector | StateSystem, + data: PassData, + ) -> Circuit: + """ + Generate the initial layer, see LayerGenerator for more. + + Raises: + ValueError: If `target` has a radix mismatch with + `self.initial_layer_gate`. + """ + + if not isinstance(target, (UnitaryMatrix, StateVector, StateSystem)): + m = f'Expected unitary or state, got {type(target)}.' + raise TypeError(m) + + for radix in target.radixes: + if radix != self.gateset[0].radixes[0]: + m = 'Radix mismatch between target and gateset.' 
+ raise ValueError(m) + + init_circuit = Circuit(target.num_qudits, target.radixes) + + if self.double_headed: + n = target.num_qudits + span = list(range(n)) + init_circuit.append_gate(self.dividing_gate_type(n), span) + + return init_circuit + + def cancels_something( + self, + circuit: Circuit, + gate: Gate, + location: tuple[int, ...], + ) -> bool: + """Ensure applying gate at location does not cancel a previous gate.""" + last_cycle = circuit.num_cycles - 1 + try: + op = circuit.get_operation((last_cycle, location[0])) + op_gate, op_location = op.gate, op.location + if op_location == location and op_gate.get_inverse() == gate: + return True + return False + except IndexError: + return False + + def count_repeats( + self, + circuit: Circuit, + gate: Gate, + qudit: int, + ) -> int: + """Count the number of times the last gate is repeated on qudit.""" + count = 0 + for cycle in reversed(range(circuit.num_cycles)): + try: + op = circuit.get_operation((cycle, qudit)) + if op.gate == gate: + count += 1 + else: + return count + except IndexError: + continue + return count + + def gen_successors(self, circuit: Circuit, data: PassData) -> list[Circuit]: + """ + Generate the successors of a circuit node. + + Raises: + ValueError: If circuit is a single-qudit circuit. 
+ """ + if not isinstance(circuit, Circuit): + raise TypeError(f'Expected circuit, got {type(circuit)}.') + + if circuit.num_qudits < 2: + raise ValueError('Cannot expand a single-qudit circuit.') + + # Get the coupling graph + coupling_graph = data.connectivity + + # Generate successors + successors = [] + hashes = set() + singles = [gate for gate in self.gateset if gate.num_qudits == 1] + multis = [gate for gate in self.gateset if gate.num_qudits > 1] + + def add_to_successors(circuit: Circuit) -> None: + h = self.hash_circuit_structure(circuit) + if h not in hashes: + successors.append(circuit) + hashes.add(h) + + for gate in singles: + for qudit in range(circuit.num_qudits): + if gate.radixes[0] != circuit.radixes[qudit]: + continue + if self.cancels_something(circuit, gate, (qudit,)): + continue + if isinstance(gate, TGate): + if self.count_repeats(circuit, TGate(), qudit) >= 7: + continue + successor = circuit.copy() + successor.append_gate(gate, [qudit]) + + add_to_successors(successor) + + if self.double_headed: + successor = circuit.copy() + op = Operation(gate, [qudit]) + successor.insert(0, op) + add_to_successors(successor) + + for gate in multis: + for edge in coupling_graph: + if self.cancels_something(circuit, gate, edge): + continue + qudit_radixes = [circuit.radixes[q] for q in edge] + if gate.radixes != qudit_radixes: + continue + successor = circuit.copy() + successor.append_gate(gate, edge) + add_to_successors(successor) + + if self.double_headed: + successor = circuit.copy() + op = Operation(gate, edge) + successor.insert(0, op) + add_to_successors(successor) + + return successors + + def hash_circuit_structure(self, circuit: Circuit) -> int: + hashes = [] + for op in circuit: + hashes.append(hash(op)) + return hash(tuple(hashes)) diff --git a/bqskit/passes/synthesis/__init__.py b/bqskit/passes/synthesis/__init__.py index 0d4e49c6b..eea097231 100644 --- a/bqskit/passes/synthesis/__init__.py +++ b/bqskit/passes/synthesis/__init__.py @@ -1,6 +1,7 
@@ """This package implements synthesis passes and synthesis related classes.""" from __future__ import annotations +from bqskit.passes.synthesis.diagonal import WalshDiagonalSynthesisPass from bqskit.passes.synthesis.leap import LEAPSynthesisPass from bqskit.passes.synthesis.pas import PermutationAwareSynthesisPass from bqskit.passes.synthesis.qfast import QFASTDecompositionPass @@ -17,4 +18,5 @@ 'SynthesisPass', 'SetTargetPass', 'PermutationAwareSynthesisPass', + 'WalshDiagonalSynthesisPass', ] diff --git a/bqskit/passes/synthesis/diagonal.py b/bqskit/passes/synthesis/diagonal.py new file mode 100644 index 000000000..969703a5f --- /dev/null +++ b/bqskit/passes/synthesis/diagonal.py @@ -0,0 +1,112 @@ +"""This module implements the WalshDiagonalSynthesisPass.""" +from __future__ import annotations + +import logging + +from numpy import where + +from bqskit.compiler.passdata import PassData +from bqskit.ir.circuit import Circuit +from bqskit.ir.gates import CNOTGate +from bqskit.ir.gates import RZGate +from bqskit.passes.synthesis.synthesis import SynthesisPass +from bqskit.qis.state.state import StateVector +from bqskit.qis.state.system import StateSystem +from bqskit.qis.unitary import UnitaryMatrix +from bqskit.utils.math import pauliz_expansion +from bqskit.utils.math import unitary_log_no_i + + +_logger = logging.getLogger(__name__) + + +class WalshDiagonalSynthesisPass(SynthesisPass): + """ + A pass that synthesizes diagonal unitaries into Walsh functions. + + Based on: https://arxiv.org/abs/1306.3991 + """ + + def __init__( + self, + parameter_precision: float = 1e-8, + ) -> None: + """ + Constructor for WalshDiagonalSynthesisPass. + + Args: + parameter_precision (float): Pauli strings with parameter values + less than this are rounded to zero. 
(Default: 1e-8) + + TODO: + - Cancel adjacent CNOTs + - See how QFAST can be used to generalize to qudits + """ + self.parameter_precision = parameter_precision + + def gray_code(self, number: int) -> int: + """Convert a number to its Gray code representation.""" + gray = number ^ (number >> 1) + return gray + + def pauli_to_subcircuit( + self, + string_id: int, + angle: float, + num_qubits: int, + ) -> Circuit: + string = bin(string_id)[2:].zfill(num_qubits) + circuit = Circuit(num_qubits) + locations = [i for i in range(num_qubits) if string[i] == '1'] + if len(locations) == 1: + circuit.append_gate(RZGate(), locations[0], [angle]) + elif len(locations) > 1: + pairs = [ + (locations[i], locations[i + 1]) + for i in range(len(locations) - 1) + ] + for pair in pairs: + circuit.append_gate(CNOTGate(), pair) + circuit.append_gate(RZGate(), locations[-1], [angle]) + for pair in reversed(pairs): + circuit.append_gate(CNOTGate(), pair) + return circuit + + async def synthesize( + self, + utry: UnitaryMatrix | StateVector | StateSystem, + data: PassData, + ) -> Circuit: + """Synthesize `utry`, see :class:`SynthesisPass` for more.""" + if not isinstance(utry, UnitaryMatrix): + m = 'WalshDiagonalSynthesisPass can only synthesize diagonal, ' + m += f'`UnitaryMatrix`s, got {type(utry)}.' + raise TypeError(m) + + if not utry.is_qubit_only(): + m = 'WalshDiagonalSynthesisPass can only synthesize diagonal ' + m += '`UnitaryMatrix`s with qubits, got higher radix than 2.' 
+ raise ValueError(m) + + num_qubits = utry.num_qudits + circuit = Circuit(num_qubits) + + # Find parameters of each I/Z Pauli string + H_matrix = unitary_log_no_i(utry.numpy) + params = pauliz_expansion(H_matrix) * 2 + # Remove low weight terms - these are likely numerical errors + params = where(abs(params) < self.parameter_precision, 0, params) + + # Order the Pauli strings by their Gray code representation + pauli_params = sorted( + [(i, -p) for i, p in enumerate(params)], + key=lambda x: self.gray_code(x[0]), + ) + subcircuits = [ + self.pauli_to_subcircuit(i, p, num_qubits) for i, p in pauli_params + ] + + for subcircuit in subcircuits: + circuit.append_circuit(subcircuit, [_ for _ in range(num_qubits)]) + + return circuit diff --git a/bqskit/passes/synthesis/leap.py b/bqskit/passes/synthesis/leap.py index da7475fae..f05300eef 100644 --- a/bqskit/passes/synthesis/leap.py +++ b/bqskit/passes/synthesis/leap.py @@ -196,7 +196,7 @@ async def synthesize( # Evalute initial layer if best_dist < self.success_threshold: - _logger.debug('Successful synthesis.') + _logger.debug('Successful synthesis with 0 layers.') return initial_layer # Main loop @@ -222,7 +222,9 @@ async def synthesize( dist = self.cost.calc_cost(circuit, utry) if dist < self.success_threshold: - _logger.debug('Successful synthesis.') + _logger.debug( + f'Successful synthesis with {layer + 1} layers.', + ) if self.store_partial_solutions: data['psols'] = psols return circuit diff --git a/bqskit/passes/synthesis/pas.py b/bqskit/passes/synthesis/pas.py index 2cad278d2..9a109f15a 100644 --- a/bqskit/passes/synthesis/pas.py +++ b/bqskit/passes/synthesis/pas.py @@ -113,6 +113,7 @@ async def synthesize( self.inner_synthesis.synthesize, targets, [data] * len(targets), + log_context=[{'perm': str(perm)} for perm in permsbyperms], ) # Return best circuit diff --git a/bqskit/passes/synthesis/qfast.py b/bqskit/passes/synthesis/qfast.py index 9a72bd79c..e4e036fb1 100644 --- a/bqskit/passes/synthesis/qfast.py +++ 
b/bqskit/passes/synthesis/qfast.py @@ -164,7 +164,7 @@ async def synthesize( if dist < self.success_threshold: self.finalize(circuit, utry, instantiate_options) - _logger.info('Successful synthesis.') + _logger.info(f'Successful synthesis with {depth} layers.') return circuit # Expand or restrict head diff --git a/bqskit/passes/synthesis/qsearch.py b/bqskit/passes/synthesis/qsearch.py index 13276ad82..9cad4fc44 100644 --- a/bqskit/passes/synthesis/qsearch.py +++ b/bqskit/passes/synthesis/qsearch.py @@ -171,7 +171,7 @@ async def synthesize( # Evalute initial layer if best_dist < self.success_threshold: - _logger.debug('Successful synthesis.') + _logger.debug('Successful synthesis with 0 layers.') return initial_layer # Main loop @@ -197,7 +197,9 @@ async def synthesize( dist = self.cost.calc_cost(circuit, utry) if dist < self.success_threshold: - _logger.debug('Successful synthesis.') + _logger.debug( + f'Successful synthesis with {layer + 1} layers.', + ) if self.store_partial_solutions: data['psols'] = psols return circuit @@ -210,7 +212,7 @@ async def synthesize( ) best_dist = dist best_circ = circuit - best_layer = layer + best_layer = layer + 1 if self.store_partial_solutions: if layer not in psols: diff --git a/bqskit/qis/graph.py b/bqskit/qis/graph.py index 06994ef34..54e0664ec 100644 --- a/bqskit/qis/graph.py +++ b/bqskit/qis/graph.py @@ -6,11 +6,10 @@ import logging from random import shuffle from typing import Any -from typing import cast from typing import Collection from typing import Iterable from typing import Iterator -from typing import List +from typing import Mapping from typing import Tuple from typing import TYPE_CHECKING from typing import Union @@ -23,7 +22,7 @@ from bqskit.ir.location import CircuitLocation from bqskit.ir.location import CircuitLocationLike from bqskit.utils.typing import is_integer -from bqskit.utils.typing import is_iterable +from bqskit.utils.typing import is_iterable, is_mapping, is_real_number _logger = 
logging.getLogger(__name__) @@ -33,31 +32,143 @@ class CouplingGraph(Collection[Tuple[int, int]]): def __init__( self, - graph: Iterable[tuple[int, int]], + graph: CouplingGraphLike, num_qudits: int | None = None, + remote_edges: Iterable[tuple[int, int]] = [], + default_weight: float = 1.0, + default_remote_weight: float = 100.0, + edge_weights_overrides: Mapping[tuple[int, int], float] = {}, ) -> None: + """ + Construct a new CouplingGraph. + + Args: + graph (CouplingGraphLike): The undirected graph edges. + + num_qudits (int | None): The number of qudits in the graph. If + None, the number of qudits is inferred from the maximum seen + in the edge list. (Default: None) + + remote_edges (Iterable[tuple[int, int]]): The edges that cross + QPU chip boundaries. Distributed QPUs will have remote links + connect them. Note: remote edges must be specified both in + `graph` and here. (Default: []) + + default_weight (float): The default weight of an edge in the + graph. (Default: 1.0) + + default_remote_weight (float): The default weight of a remote + edge in the graph. (Default: 100.0) + + edge_weights_overrides (Mapping[tuple[int, int], float]): A mapping + of edges to their weights. These override the defaults on + a case-by-case basis. (Default: {}) + + Raises: + ValueError: If `num_qudits` is too small for the edges in `graph`. + + ValueError: If `num_qudits` is less than zero. + + ValueError: If any edge in `remote_edges` is not in `graph`. + + ValueError: If any edge in `edge_weights_overrides` is not in + `graph`.
+ """ + if not CouplingGraph.is_valid_coupling_graph(graph): + raise TypeError('Invalid coupling graph.') + + if num_qudits is not None and not is_integer(num_qudits): + raise TypeError( + 'Expected integer for num_qudits,' + f' got {type(num_qudits)}', + ) + + if num_qudits is not None and num_qudits < 0: + raise ValueError( + 'Expected nonnegative num_qudits,' + f' got {num_qudits}.', + ) + + if not CouplingGraph.is_valid_coupling_graph(remote_edges): + raise TypeError('Invalid remote links.') + + if any(edge not in graph for edge in remote_edges): + invalids = [e for e in remote_edges if e not in graph] + raise ValueError( + f'Remote links {invalids} not in graph.' + ' All remote links must also be specified in the graph input.', + ) + + if not is_real_number(default_weight): + raise TypeError( + 'Expected integer for default_weight,' + f' got {type(default_weight)}', + ) + + if not is_real_number(default_remote_weight): + raise TypeError( + 'Expected integer for default_remote_weight,' + f' got {type(default_remote_weight)}', + ) + + if not is_mapping(edge_weights_overrides): + raise TypeError( + 'Expected mapping for edge_weights_overrides,' + f' got {type(edge_weights_overrides)}', + ) + + if any( + not is_real_number(v) + for v in edge_weights_overrides.values() + ): + invalids = [ + v for v in edge_weights_overrides.values() + if not is_real_number(v) + ] + raise TypeError( + 'Expected integer values for edge_weights_overrides,' + f' got non-integer values: {invalids}.', + ) + + if any(edge not in graph for edge in edge_weights_overrides): + invalids = [ + e for e in edge_weights_overrides + if e not in graph + ] + raise ValueError( + f'Edges {invalids} from edge_weights_overrides are not in ' + 'the graph. 
All edge_weights_overrides must also be ' + 'specified in the graph input.', + ) + if isinstance(graph, CouplingGraph): self.num_qudits: int = graph.num_qudits self._edges: set[tuple[int, int]] = graph._edges + self._remote_edges: set[tuple[int, int]] = graph._remote_edges self._adj: list[set[int]] = graph._adj + self._mat: list[list[float]] = graph._mat + self.default_weight: float = graph.default_weight + self.default_remote_weight: float = graph.default_remote_weight return - if not CouplingGraph.is_valid_coupling_graph(graph): - raise TypeError('Invalid coupling graph.') - - self._edges = {g if g[0] <= g[1] else (g[1], g[0]) for g in graph} + calc_num_qudits = 0 + for q1, q2 in graph: + calc_num_qudits = max(calc_num_qudits, max(q1, q2)) + calc_num_qudits += 1 - calced_num_qudits = 0 - for q1, q2 in self._edges: - calced_num_qudits = max(calced_num_qudits, max(q1, q2)) - calced_num_qudits += 1 + if num_qudits is not None and calc_num_qudits > num_qudits: + raise ValueError( + 'Edges between invalid qudits or num_qudits too small.', + ) - if num_qudits is None: - self.num_qudits = calced_num_qudits - elif calced_num_qudits > num_qudits: - raise ValueError('Edges between invalid qudits.') - else: - self.num_qudits = num_qudits + self.num_qudits = calc_num_qudits if num_qudits is None else num_qudits + self._edges = {g if g[0] <= g[1] else (g[1], g[0]) for g in graph} + self._remote_edges = { + e if e[0] <= e[1] else (e[1], e[0]) + for e in remote_edges + } + self.default_weight = default_weight + self.default_remote_weight = default_remote_weight self._adj = [set() for _ in range(self.num_qudits)] for q1, q2 in self._edges: @@ -69,8 +180,77 @@ def __init__( for _ in range(self.num_qudits) ] for q1, q2 in self._edges: - self._mat[q1][q2] = 1 - self._mat[q2][q1] = 1 + self._mat[q1][q2] = default_weight + self._mat[q2][q1] = default_weight + + for q1, q2 in self._remote_edges: + self._mat[q1][q2] = default_remote_weight + self._mat[q2][q1] = default_remote_weight + 
+ for (q1, q2), weight in edge_weights_overrides.items(): + self._mat[q1][q2] = weight + self._mat[q2][q1] = weight + + def get_qpu_to_qudit_map(self) -> list[list[int]]: + """Return a mapping of QPU indices to qudit indices.""" + if not hasattr(self, '_qpu_to_qudit'): + seen = set() + self._qpu_to_qudit = [] + for qudit in range(self.num_qudits): + if qudit in seen: + continue + qpu = [] + frontier = {qudit} + while len(frontier) > 0: + node = frontier.pop() + qpu.append(node) + seen.add(node) + for neighbor in self._adj[node]: + if (node, neighbor) in self._remote_edges: + continue + if (neighbor, node) in self._remote_edges: + continue + if neighbor not in seen: + frontier.add(neighbor) + self._qpu_to_qudit.append(qpu) + return self._qpu_to_qudit + + def is_distributed(self) -> bool: + """Return true if the graph represents multiple connected QPUs.""" + return len(self._remote_edges) > 0 + + def qpu_count(self) -> int: + """Return the number of connected QPUs.""" + return len(self.get_qpu_to_qudit_map()) + + def get_individual_qpu_graphs(self) -> list[CouplingGraph]: + """Return a list of individual QPU graphs.""" + if not self.is_distributed(): + return [self] + + qpu_to_qudit = self.get_qpu_to_qudit_map() + return [self.get_subgraph(qpu) for qpu in qpu_to_qudit] + + def get_qudit_to_qpu_map(self) -> list[int]: + """Return a mapping of qudit indices to QPU indices.""" + qpu_to_qudit = self.get_qpu_to_qudit_map() + qudit_to_qpu = {} + for qpu, qudits in enumerate(qpu_to_qudit): + for qudit in qudits: + qudit_to_qpu[qudit] = qpu + return list(qudit_to_qpu.values()) + + def get_qpu_connectivity(self) -> list[set[int]]: + """Return the adjacency list of the QPUs.""" + qpu_to_qudit = self.get_qpu_to_qudit_map() + qudit_to_qpu = self.get_qudit_to_qpu_map() + qpu_adj: list[set[int]] = [set() for _ in range(len(qpu_to_qudit))] + for q1, q2 in self._remote_edges: + qpu1 = qudit_to_qpu[q1] + qpu2 = qudit_to_qpu[q2] + qpu_adj[qpu1].add(qpu2) + qpu_adj[qpu2].add(qpu1) + 
return qpu_adj def is_fully_connected(self) -> bool: """Return true if the graph is fully connected.""" @@ -92,6 +272,27 @@ def is_fully_connected(self) -> bool: return False + def is_linear(self) -> bool: + """Return true if the graph is linearly connected.""" + if self.num_qudits < 2: + return False + + num_deg_1 = 0 + for node_neighbors in self._adj: + if len(node_neighbors) == 1: + num_deg_1 += 1 + + elif len(node_neighbors) == 0: + return False + + elif len(node_neighbors) > 2: + return False + + if num_deg_1 != 2: + return False + + return True + def get_neighbors_of(self, qudit: int) -> list[int]: """Return the qudits adjacent to `qudit`.""" return list(self._adj[qudit]) @@ -129,12 +330,12 @@ def __repr__(self) -> str: def get_qudit_degrees(self) -> list[int]: return [len(l) for l in self._adj] - def all_pairs_shortest_path(self) -> list[list[int]]: + def all_pairs_shortest_path(self) -> list[list[float]]: """ Calculate all pairs shortest path matrix using Floyd-Warshall. Returns: - D (list[list[int]]): D[i][j] is the length of the shortest + D (list[list[float]]): D[i][j] is the length of the shortest path from i to j. 
""" D = copy.deepcopy(self._mat) @@ -142,7 +343,7 @@ def all_pairs_shortest_path(self) -> list[list[int]]: for i in range(self.num_qudits): for j in range(self.num_qudits): D[i][j] = min(D[i][j], D[i][k] + D[k][j]) - return cast(List[List[int]], D) + return D def get_shortest_path_tree(self, source: int) -> list[tuple[int, ...]]: """Return shortest path from `source` to every node in `self`.""" diff --git a/bqskit/qis/pauliz.py b/bqskit/qis/pauliz.py new file mode 100644 index 000000000..101324596 --- /dev/null +++ b/bqskit/qis/pauliz.py @@ -0,0 +1,268 @@ +"""This module implements the PauliZMatrices class.""" +from __future__ import annotations + +import itertools as it +from typing import Iterable +from typing import Iterator +from typing import overload +from typing import Sequence +from typing import TYPE_CHECKING + +import numpy as np +import numpy.typing as npt + +from bqskit.utils.typing import is_integer +from bqskit.utils.typing import is_numeric +from bqskit.utils.typing import is_sequence + +if TYPE_CHECKING: + from bqskit.qis.unitary.unitary import RealVector + + +class PauliZMatrices(Sequence[npt.NDArray[np.complex128]]): + """ + The group of Pauli Z matrices. + + A PauliZMatrices object represents the entire of set of diagonal Hermitian + matrices for some number of qubits. These matrices are a linear combination + of all n-fold tensor products of Pauli Z and the identity matrix. + + Examples: + .. math:: + I + Z = \\begin{pmatrix} + 2 & 0 \\\\ + 0 & 0 \\\\ + \\end{pmatrix} + + .. 
math:: + I \\otimes I + Z \\otimes I + 3 I \\otimes Z - Z \\otimes Z = + \\begin{pmatrix} + 4 & 0 & 0 & 0 \\\\ + 0 & -2 & 0 & 0 \\\\ + 0 & 0 & 2 & 0 \\\\ + 0 & 0 & 0 & -2 \\\\ + \\end{pmatrix} + """ + + Z = np.array( + [ + [1, 0], + [0, -1], + ], dtype=np.complex128, + ) + """The Pauli Z Matrix.""" + + I = np.array( + [ + [1, 0], + [0, 1], + ], dtype=np.complex128, + ) + """The Identity Matrix.""" + + def __init__(self, num_qudits: int) -> None: + """ + Construct the Pauli Z group for `num_qudits` number of qubits. + + Args: + num_qudits (int): Power of the tensor product of the Pauli Z + group. + + Raises: + ValueError: If `num_qudits` is less than or equal to 0. + """ + + if not is_integer(num_qudits): + raise TypeError( + 'Expected integer for num_qudits, got %s.' % + type(num_qudits), + ) + + if num_qudits <= 0: + raise ValueError( + 'Expected positive integer for num_qudits, got %s.' % type( + num_qudits, + ), + ) + + self.num_qudits = num_qudits + + if num_qudits == 1: + self.paulizs = [ + PauliZMatrices.I, + PauliZMatrices.Z, + ] + else: + self.paulizs = [] + matrices = it.product( + PauliZMatrices( + num_qudits - 1, + ), + PauliZMatrices(1), + ) + for pauliz_n_1, pauliz_1 in matrices: + self.paulizs.append(np.kron(pauliz_n_1, pauliz_1)) + + def __iter__(self) -> Iterator[npt.NDArray[np.complex128]]: + return self.paulizs.__iter__() + + @overload + def __getitem__(self, index: int) -> npt.NDArray[np.complex128]: + ... + + @overload + def __getitem__(self, index: slice) -> list[npt.NDArray[np.complex128]]: + ... 
+ + def __getitem__( + self, + index: int | slice, + ) -> npt.NDArray[np.complex128] | list[npt.NDArray[np.complex128]]: + return self.paulizs[index] + + def __len__(self) -> int: + return len(self.paulizs) + + @property + def numpy(self) -> npt.NDArray[np.complex128]: + """The NumPy array holding the pauliz matrices.""" + return np.array(self.paulizs) + + def __array__( + self, + dtype: np.typing.DTypeLike = np.complex128, + ) -> npt.NDArray[np.complex128]: + """Implements NumPy API for the PauliZMatrices class.""" + if dtype != np.complex128: + raise ValueError('PauliZMatrices only supports Complex128 dtype.') + + return np.array(self.paulizs, dtype) + + def get_projection_matrices( + self, q_set: Iterable[int], + ) -> list[npt.NDArray[np.complex128]]: + """ + Return the Pauli Z matrices that act only on qubits in `q_set`. + + Args: + q_set (Iterable[int]): Active qubit indices + + Returns: + list[np.ndarray]: Pauli Z matrices from `self` acting only + on qubits in `q_set`. + + Raises: + ValueError: if `q_set` is an invalid set of qubit indices. 
+ """ + q_set = list(q_set) + + if not all(is_integer(q) for q in q_set): + raise TypeError('Expected sequence of integers for qubit indices.') + + if any(q < 0 or q >= self.num_qudits for q in q_set): + raise ValueError('Qubit indices must be in [0, n).') + + if len(q_set) != len(set(q_set)): + raise ValueError('Qubit indices cannot have duplicates.') + + # Nth Order Pauli Z Matrices can be thought of base 2 number + # I = 0, Z = 1 + # IZZ = 0 * 2^2 + 1 * 2^1 + 1 * 2^0 = 3 (base 10) + # This gives the idx of IZZ in paulizs + # Note we read qubit index from the left, + # so Z in ZII corresponds to q = 0 + pauliz_n_qubit = [] + for ps in it.product([0, 1], repeat=len(q_set)): + idx = 0 + for p, q in zip(ps, q_set): + idx += p * (2 ** (self.num_qudits - q - 1)) + pauliz_n_qubit.append(self.paulizs[idx]) + + return pauliz_n_qubit + + def dot_product(self, alpha: RealVector) -> npt.NDArray[np.complex128]: + """ + Computes the standard dot product of `alpha` with the paulis. + + Args: + alpha (RealVector): The Pauli Z coefficients. + + Returns: + np.ndarray: Sum of element-wise multiplication of `alpha` + and `self.paulizs`. + + Raises: + ValueError: If `alpha` and `self.paulizs` are incompatible. + """ + + if not is_sequence(alpha) or not all(is_numeric(a) for a in alpha): + msg = f'Expected a sequence of numbers, got {type(alpha)}.' + raise TypeError(msg) + + if len(alpha) != len(self): + msg = ( + 'Incorrect number of alpha values, expected ' + f'{len(self)}, got {len(alpha)}.' + ) + raise ValueError(msg) + + return np.array(np.sum([a * s for a, s in zip(alpha, self.paulizs)], 0)) + + @staticmethod + def from_string( + pauliz_string: str, + ) -> npt.NDArray[np.complex128] | list[npt.NDArray[np.complex128]]: + """ + Construct Pauli Z matrices from a string description. + + Args: + pauliz_string (str): A string that describes the desired matrices. + This is a comma-separated list of Pauli Z strings.
+ A Pauli Z string has the following regex pattern: [IZ]+ + + Returns: + np.ndarray | list[np.ndarray]: Either the single Pauli Z matrix + if only one is constructed, or the list of the constructed + Pauli Z matrices. + + Raises: + ValueError: if `pauliz_string` is invalid. + """ + + if not isinstance(pauliz_string, str): + msg = f'Expected str for pauliz_string, got {type(pauliz_string)}.' + raise TypeError(msg) + + pauliz_strings = [ + string.strip().upper() + for string in pauliz_string.split(',') + if len(string.strip()) > 0 + ] + + pauliz_matrices = [] + idx_dict = {'I': 0, 'Z': 1} + mat_dict = { + 'I': PauliZMatrices.I, + 'Z': PauliZMatrices.Z, + } + + for pauli_string in pauliz_strings: + if not all(char in 'IZ' for char in pauli_string): + raise ValueError('Invalid Pauli Z string.') + + if len(pauli_string) <= 6: + idx = 0 + for char in pauli_string: + idx *= 2 + idx += idx_dict[char] + pauliz_matrices.append(PauliZMatrices(len(pauli_string))[idx]) + else: + acm = mat_dict[pauli_string[0]] + for char in pauli_string[1:]: + acm = np.kron(acm, mat_dict[char]) + pauliz_matrices.append(acm) + + if len(pauliz_matrices) == 1: + return pauliz_matrices[0] + + return pauliz_matrices diff --git a/bqskit/qis/state/state.py b/bqskit/qis/state/state.py index bd8f93e88..d29d61521 100644 --- a/bqskit/qis/state/state.py +++ b/bqskit/qis/state/state.py @@ -433,4 +433,8 @@ def __repr__(self) -> str: return repr(self._vec) -StateLike = Union[StateVector, np.ndarray, Sequence[Union[int, float, complex]]] +StateLike = Union[ + StateVector, + npt.NDArray[np.complex128], + Sequence[Union[int, float, complex]], +] diff --git a/bqskit/qis/unitary/unitary.py b/bqskit/qis/unitary/unitary.py index cb4ada60f..b9f9966d5 100644 --- a/bqskit/qis/unitary/unitary.py +++ b/bqskit/qis/unitary/unitary.py @@ -7,6 +7,7 @@ from typing import Union import numpy as np +import numpy.typing as npt from bqskit.qis.unitary.meta import UnitaryMeta from bqskit.utils.typing import is_real_number @@ -53,7 
+54,14 @@ def dim(self) -> int: if hasattr(self, '_dim'): return self._dim - return int(np.prod(self.radixes)) + # return int(np.prod(self.radixes)) + # Above line removed due to failure to handle overflow and + # underflows for large dimensions. + + acm = 1 + for radix in self.radixes: + acm *= int(radix) + return acm @abc.abstractmethod def get_unitary(self, params: RealVector = []) -> UnitaryMatrix: @@ -151,4 +159,8 @@ def is_self_inverse(self, params: RealVector = []) -> bool: return np.allclose(unitary_matrix, hermitian_conjugate) -RealVector = Union[Sequence[float], np.ndarray] +RealVector = Union[ + Sequence[float], + npt.NDArray[np.float64], + npt.NDArray[np.float32], +] diff --git a/bqskit/qis/unitary/unitarymatrix.py b/bqskit/qis/unitary/unitarymatrix.py index 10d55780d..04fafc616 100644 --- a/bqskit/qis/unitary/unitarymatrix.py +++ b/bqskit/qis/unitary/unitarymatrix.py @@ -199,6 +199,22 @@ def otimes(self, *utrys: UnitaryLike) -> UnitaryMatrix: return UnitaryMatrix(utry_acm, radixes_acm) + def ipower(self, power: int) -> UnitaryMatrix: + """ + Calculate the integer power of this unitary. + + Args: + power (int): The integer power to raise the unitary to. + + Returns: + UnitaryMatrix: The resulting unitary matrix. + """ + if power < 0: + mat = np.linalg.matrix_power(self.dagger, -power) + else: + mat = np.linalg.matrix_power(self, power) + return UnitaryMatrix(mat, self.radixes) + def get_unitary(self, params: RealVector = []) -> UnitaryMatrix: """Return the same object, satisfies the :class:`Unitary` API.""" return self @@ -232,6 +248,23 @@ def get_distance_from(self, other: UnitaryLike, degree: int = 2) -> float: dist = np.power(1 - (frac ** degree), 1.0 / degree) return dist if dist > 0.0 else 0.0 + def isclose(self, other: UnitaryLike, tol: float = 1e-6) -> bool: + """ + Check if `self` is approximately equal to `other` upto global phase. + + Args: + other (UnitaryLike): The unitary to compare to. + + tol (float): The numerical precision of the check. 
+ + Returns: + bool: True if `self` is close to `other`. + + See Also: + - :func:`get_distance_from` for the error function used. + """ + return self.get_distance_from(other) < tol + def get_statevector(self, in_state: StateLike) -> StateVector: """ Calculate the output state after applying this unitary to `in_state`. @@ -507,6 +540,11 @@ def __hash__(self) -> int: UnitaryLike = Union[ UnitaryMatrix, - np.ndarray, + npt.NDArray[np.complex128], + npt.NDArray[np.complex64], + npt.NDArray[np.int64], + npt.NDArray[np.int32], + npt.NDArray[np.float64], + npt.NDArray[np.float32], Sequence[Sequence[Union[int, float, complex]]], ] diff --git a/bqskit/runtime/__init__.py b/bqskit/runtime/__init__.py index 31764cb46..75e3691c7 100644 --- a/bqskit/runtime/__init__.py +++ b/bqskit/runtime/__init__.py @@ -70,10 +70,8 @@ :class:`RuntimeHandle`, which you can use to submit, map, wait on, and cancel tasks in the execution environment. -For more information on how to design a custom pass, see this (TODO, sorry, -you can look at the source code of existing -`passes `_ -for a good example for the time being). +For more information on how to design a custom pass, see the following +guide: :doc:`guides/custompass.md`. .. autosummary:: :toctree: autogen @@ -98,23 +96,30 @@ from typing import Any from typing import Callable from typing import Protocol +from typing import Sequence from typing import TYPE_CHECKING -# Enable low-level fault handling: system crashes print a minimal trace. -faulthandler.enable() +if TYPE_CHECKING: + from bqskit.runtime.future import RuntimeFuture -# Disable multi-threading in BLAS libraries. -os.environ['OMP_NUM_THREADS'] = '1' -os.environ['OPENBLAS_NUM_THREADS'] = '1' -os.environ['MKL_NUM_THREADS'] = '1' -os.environ['NUMEXPR_NUM_THREADS'] = '1' -os.environ['VECLIB_MAXIMUM_THREADS'] = '1' +# Enable low-level fault handling: system crashes print a minimal trace. 
+faulthandler.enable() os.environ['RUST_BACKTRACE'] = '1' -if TYPE_CHECKING: - from bqskit.runtime.future import RuntimeFuture +# Control multi-threading in BLAS libraries. +def set_blas_thread_counts(i: int = 1) -> None: + """ + Control number of threads used by numpy and others. + + Must be called before any numpy or other BLAS libraries are loaded. + """ + os.environ['OMP_NUM_THREADS'] = str(i) + os.environ['OPENBLAS_NUM_THREADS'] = str(i) + os.environ['MKL_NUM_THREADS'] = str(i) + os.environ['NUMEXPR_NUM_THREADS'] = str(i) + os.environ['VECLIB_MAXIMUM_THREADS'] = str(i) class RuntimeHandle(Protocol): @@ -137,18 +142,129 @@ def submit( self, fn: Callable[..., Any], *args: Any, + task_name: str | None = None, + log_context: dict[str, str] = {}, **kwargs: Any, ) -> RuntimeFuture: - """Submit a `fn` to the runtime.""" + """ + Submit a function to the runtime for execution. + + This method schedules the function `fn` to be executed by the + runtime with the provided arguments `args` and keyword arguments + `kwargs`. The execution may happen asynchronously. + + Args: + fn (Callable[..., Any]): The function to be executed. + + *args (Any): Variable length argument list to be passed to + the function `fn`. + + task_name (str | None): An optional name for the task, which + can be used for logging or tracking purposes. Defaults to + None, which will use the function name as the task name. + + log_context (dict[str, str]): A dictionary containing logging + context information. All log messages produced by the fn + and any children tasks will contain this context if the + appropriate level (logging.DEBUG) is set on the logger. + Defaults to an empty dictionary for no added context. + + **kwargs (Any): Arbitrary keyword arguments to be passed to + the function `fn`. + + Returns: + RuntimeFuture: An object representing the future result of + the function execution. This can be used to retrieve the + result by `await`ing it. 
+ + Example: + >>> from bqskit.runtime import get_runtime + >>> + >>> def add(x, y): + ... return x + y + >>> + >>> future = get_runtime().submit(add, 1, 2) + >>> result = await future + >>> print(result) + 3 + + See Also: + - :func:`map` for submitting multiple tasks in parallel. + - :func:`cancel` for cancelling tasks. + - :class:`~bqskit.runtime.future.RuntimeFuture` for more + information on how to interact with the future object. + """ ... def map( self, fn: Callable[..., Any], *args: Any, + task_name: Sequence[str | None] | str | None = None, + log_context: Sequence[dict[str, str]] | dict[str, str] = {}, **kwargs: Any, ) -> RuntimeFuture: - """Map `fn` over the input arguments distributed across the runtime.""" + """ + Map a function over a sequence of arguments and execute in parallel. + + This method schedules the function `fn` to be executed by the runtime + for each set of arguments provided in `args`. Each invocation of `fn` + will be executed potentially in parallel, depending on the runtime's + capabilities and current load. + + Args: + fn (Callable[..., Any]): The function to be executed. + + *args (Any): Variable length argument list to be passed to + the function `fn`. Each argument is expected to be a + sequence of arguments to be passed to a separate + invocation. The sequences should be of equal length. + + task_name (Sequence[str | None] | str | None): An optional + name for the task group, which can be used for logging + or tracking purposes. Defaults to None, which will use + the function name as the task name. If a string is + provided, it will be used as the prefix for all task + names. If a sequence of strings is provided, each task + will be named with the corresponding string in the + sequence. + + log_context (Sequence[dict[str, str]]) | dict[str, str]): A + dictionary containing logging context information. 
All + log messages produced by the `fn` and any children tasks + will contain this context if the appropriate level + (logging.DEBUG) is set on the logger. Defaults to an + empty dictionary for no added context. Can be a sequence + of contexts, one for each task, or a single context to be + used for all tasks. + + **kwargs (Any): Arbitrary keyword arguments to be passed to + each invocation of the function `fn`. + + Returns: + RuntimeFuture: An object representing the future result of + the function executions. This can be used to retrieve the + results by `await`ing it, which will return a list. + + Example: + >>> from bqskit.runtime import get_runtime + >>> + >>> def add(x, y): + ... return x + y + >>> + >>> args_list = [(1, 2, 3), (4, 5, 6)] + >>> future = get_runtime().map(add, *args_list) + >>> results = await future + >>> print(results) + [5, 7, 9] + + See Also: + - :func:`submit` for submitting a single task. + - :func:`cancel` for cancelling tasks. + - :func:`next` for retrieving results incrementally. + - :class:`~bqskit.runtime.future.RuntimeFuture` for more + information on how to interact with the future object. + """ ... def cancel(self, future: RuntimeFuture) -> None: diff --git a/bqskit/runtime/attached.py b/bqskit/runtime/attached.py index 6cd6a6a98..f68be0ee3 100644 --- a/bqskit/runtime/attached.py +++ b/bqskit/runtime/attached.py @@ -15,6 +15,9 @@ from bqskit.runtime.direction import MessageDirection +_logger = logging.getLogger(__name__) + + class AttachedServer(DetachedServer): """ BQSKit Runtime Server in attached mode. @@ -33,6 +36,8 @@ def __init__( num_workers: int = -1, port: int = default_server_port, worker_port: int = default_worker_port, + log_level: int = logging.WARNING, + num_blas_threads: int = 1, ) -> None: """ Create a server with `num_workers` workers. @@ -49,7 +54,24 @@ def __init__( worker_port (int): The port this server will listen for workers on. 
Default can be found in the :obj:`~bqskit.runtime.default_worker_port` global variable. + + log_level (int): The logging level for the server and workers. + (Default: logging.WARNING). + + num_blas_threads (int): The number of threads to use in BLAS + libraries. (Default: 1). """ + # Initialize runtime logging + logging.getLogger().setLevel(log_level) + _handler = logging.StreamHandler() + _handler.setLevel(0) + _fmt_header = '%(asctime)s.%(msecs)03d - %(levelname)-8s |' + _fmt_message = ' %(module)s: %(message)s' + _fmt = _fmt_header + _fmt_message + _formatter = logging.Formatter(_fmt, '%H:%M:%S') + _handler.setFormatter(_formatter) + logging.getLogger().addHandler(_handler) + ServerBase.__init__(self) # See DetachedServer for more info on the following fields: @@ -59,9 +81,6 @@ def __init__( self.mailboxes: dict[int, ServerMailbox] = {} self.mailbox_counter = 0 - # Start workers - self.spawn_workers(num_workers, worker_port) - # Connect to client client_conn = self.listen_once('localhost', port) self.clients[client_conn] = set() @@ -70,24 +89,23 @@ def __init__( selectors.EVENT_READ, MessageDirection.CLIENT, ) - self.logger.info('Connected to client.') + _logger.info('Connected to client.') + + # Start workers + self.spawn_workers( + num_workers, + worker_port, + log_level, + num_blas_threads, + ) def handle_disconnect(self, conn: Connection) -> None: """A client disconnect in attached mode is equal to a shutdown.""" self.handle_shutdown() -def start_attached_server( - num_workers: int, - log_level: int, - **kwargs: Any, -) -> None: +def start_attached_server(num_workers: int, **kwargs: Any) -> None: """Start a runtime server in attached mode.""" - # Initialize runtime logging - _logger = logging.getLogger('bqskit-runtime') - _logger.setLevel(log_level) - _logger.addHandler(logging.StreamHandler()) - # Initialize the server server = AttachedServer(num_workers, **kwargs) diff --git a/bqskit/runtime/base.py b/bqskit/runtime/base.py index 17cdf2747..bcc9f03fc 
100644 --- a/bqskit/runtime/base.py +++ b/bqskit/runtime/base.py @@ -25,6 +25,7 @@ from bqskit.runtime import default_manager_port from bqskit.runtime import default_worker_port +from bqskit.runtime import set_blas_thread_counts from bqskit.runtime.address import RuntimeAddress from bqskit.runtime.direction import MessageDirection from bqskit.runtime.message import RuntimeMessage @@ -33,56 +34,88 @@ from bqskit.runtime.worker import start_worker +_logger = logging.getLogger(__name__) + + class RuntimeEmployee: """Data structure for a boss's view of an employee.""" def __init__( self, + id: int, conn: Connection, total_workers: int, process: Process | None = None, - num_tasks: int = 0, + is_manager: bool = False, ) -> None: """Construct an employee with all resources idle.""" + + self.id = id + """ + The ID of the employee. + + If this is a worker, then their unique worker id. If this is a manager, + then their local id. + """ + self.conn: Connection = conn self.total_workers = total_workers self.process = process - self.num_tasks = num_tasks + self.num_tasks = 0 self.num_idle_workers = total_workers + self.is_manager = is_manager - def shutdown(self) -> None: - """Shutdown the employee.""" + self.submit_cache: list[tuple[RuntimeAddress, int]] = [] + """ + Tracks recently submitted tasks by id and count. + + This is used to adjust the idle worker count when the employee sends a + waiting message. 
+ """ + + def initiate_shutdown(self) -> None: + """Instruct employee to shutdown.""" try: self.conn.send((RuntimeMessage.SHUTDOWN, None)) except Exception: pass + def complete_shutdown(self) -> None: + """Ensure employee is shutdown and clean up resources.""" if self.process is not None: self.process.join() self.process = None self.conn.close() + def shutdown(self) -> None: + """Initiate and complete shutdown.""" + self.initiate_shutdown() + self.complete_shutdown() + + @property + def recipient_string(self) -> str: + """Return a string representation of the employee.""" + return f'{"Manager" if self.is_manager else "Worker"} {self.id}' + @property def has_idle_resources(self) -> bool: return self.num_idle_workers > 0 + def get_num_of_tasks_sent_since( + self, + read_receipt: RuntimeAddress | None, + ) -> int: + """Return the number of tasks sent since the read receipt.""" + if read_receipt is None: + return sum(count for _, count in self.submit_cache) -def send_outgoing(node: ServerBase) -> None: - """Outgoing thread forwards messages as they are created.""" - while True: - outgoing = node.outgoing.get() - - if not node.running: - # NodeBase's handle_shutdown will put a dummy value in the - # queue to wake the thread up so it can exit safely. - # Hence the node.running check now rather than in the - # while condition. 
- break + for i, (addr, _) in enumerate(self.submit_cache): + if addr == read_receipt: + self.submit_cache = self.submit_cache[i:] + return sum(count for _, count in self.submit_cache[1:]) - outgoing[0].send((outgoing[1], outgoing[2])) - node.logger.debug(f'Sent message {outgoing[1].name}.') - node.logger.log(1, f'{outgoing[2]}\n') + raise RuntimeError('Read receipt not found in submit cache.') def sigint_handler(signum: int, _: FrameType | None, node: ServerBase) -> None: @@ -92,7 +125,7 @@ def sigint_handler(signum: int, _: FrameType | None, node: ServerBase) -> None: node.running = False node.terminate_hotline.send(b'\0') - node.logger.info('Server interrupted.') + _logger.info('Server interrupted.') class ServerBase: @@ -122,24 +155,24 @@ def __init__(self) -> None: self.sel.register(p, selectors.EVENT_READ, MessageDirection.SIGNAL) """Terminate hotline is used to unblock select while running.""" - self.logger = logging.getLogger('bqskit-runtime') - """Logger used to print operational log messages.""" - self.employees: list[RuntimeEmployee] = [] """Tracks this node's employees, which are managers or workers.""" self.conn_to_employee_dict: dict[Connection, RuntimeEmployee] = {} """Used to find the employee associated with a message.""" + # Servers do not need blas threads + set_blas_thread_counts(1) + # Safely and immediately exit on interrupt signals handle = functools.partial(sigint_handler, node=self) signal.signal(signal.SIGINT, handle) # Start outgoing thread self.outgoing: Queue[tuple[Connection, RuntimeMessage, Any]] = Queue() - self.outgoing_thread = Thread(target=send_outgoing, args=(self,)) + self.outgoing_thread = Thread(target=self.send_outgoing, daemon=True) self.outgoing_thread.start() - self.logger.info('Started outgoing thread.') + _logger.info('Started outgoing thread.') def connect_to_managers(self, ipports: Sequence[tuple[str, int]]) -> None: """Connect to all managers given by endpoints in `ipports`.""" @@ -155,26 +188,33 @@ def 
connect_to_managers(self, ipports: Sequence[tuple[str, int]]) -> None: self.upper_id_bound, ) manager_conns.append(self.connect_to_manager(ip, port, lb, ub)) - self.logger.info(f'Connected to manager {i} at {ip}:{port}.') - self.logger.debug(f'Gave bounds {lb=} and {ub=} to manager {i}.') + _logger.info(f'Connected to manager {i} at {ip}:{port}.') + _logger.debug(f'Gave bounds {lb=} and {ub=} to manager {i}.') # Wait for started messages from all managers and register them self.total_workers = 0 for i, conn in enumerate(manager_conns): msg, num_workers = conn.recv() assert msg == RuntimeMessage.STARTED - self.employees.append(RuntimeEmployee(conn, num_workers)) + self.employees.append( + RuntimeEmployee( + i, + conn, + num_workers, + is_manager=True, + ), + ) self.conn_to_employee_dict[conn] = self.employees[-1] self.sel.register( conn, selectors.EVENT_READ, MessageDirection.BELOW, ) - self.logger.info(f'Registered manager {i} with {num_workers=}.') + _logger.info(f'Registered manager {i} with {num_workers=}.') self.total_workers += num_workers self.num_idle_workers = self.total_workers - self.logger.info(f'Node has {self.total_workers} total workers.') + _logger.info(f'Node has {self.total_workers} total workers.') def connect_to_manager( self, @@ -216,6 +256,8 @@ def spawn_workers( self, num_workers: int = -1, port: int = default_worker_port, + logging_level: int = logging.WARNING, + num_blas_threads: int = 1, ) -> None: """ Spawn worker processes. @@ -228,6 +270,11 @@ def spawn_workers( port (int): The port this server will listen for workers on. Default can be found in the :obj:`~bqskit.runtime.default_worker_port` global variable. + + logging_level (int): The logging level for the workers. + + num_blas_threads (int): The number of threads to use in BLAS + libraries. (Default: 1). 
""" if num_workers == -1: oscount = os.cpu_count() @@ -240,9 +287,17 @@ def spawn_workers( procs = {} for i in range(num_workers): w_id = self.lower_id_bound + i - procs[w_id] = Process(target=start_worker, args=(w_id, port)) + procs[w_id] = Process( + target=start_worker, + args=(w_id, port), + kwargs={ + 'logging_level': logging_level, + 'num_blas_threads': num_blas_threads, + }, + ) + procs[w_id].daemon = True procs[w_id].start() - self.logger.debug(f'Stated worker process {i}.') + _logger.debug(f'Stated worker process {i}.') # Listen for the worker connections family = 'AF_INET' if sys.platform == 'win32' else None @@ -255,7 +310,7 @@ def spawn_workers( for i, conn in enumerate(conns): msg, w_id = conn.recv() assert msg == RuntimeMessage.STARTED - employee = RuntimeEmployee(conn, 1, procs[w_id]) + employee = RuntimeEmployee(w_id, conn, 1, procs[w_id]) temp_reorder[w_id - self.lower_id_bound] = employee self.conn_to_employee_dict[conn] = employee @@ -264,18 +319,18 @@ def spawn_workers( self.employees.append(temp_reorder[i]) # Register employee communication - for i, employee in enumerate(self.employees): + for employee in self.employees: self.sel.register( employee.conn, selectors.EVENT_READ, MessageDirection.BELOW, ) - self.logger.info(f'Registered worker {i}.') + _logger.debug(f'Registered worker {employee.id}.') self.step_size = 1 self.total_workers = num_workers self.num_idle_workers = num_workers - self.logger.info(f'Node has spawned {num_workers} workers.') + _logger.info(f'Node has spawned {num_workers} workers.') def connect_to_workers( self, @@ -298,7 +353,7 @@ def connect_to_workers( oscount = os.cpu_count() num_workers = oscount if oscount else 1 - self.logger.info(f'Expecting {num_workers} worker connections.') + _logger.info(f'Expecting {num_workers} worker connections.') if self.lower_id_bound + num_workers >= self.upper_id_bound: raise RuntimeError('Insufficient id range for workers.') @@ -312,25 +367,25 @@ def connect_to_workers( for i, conn in 
enumerate(conns): w_id = self.lower_id_bound + i self.outgoing.put((conn, RuntimeMessage.STARTED, w_id)) - employee = RuntimeEmployee(conn, 1) + employee = RuntimeEmployee(w_id, conn, 1) self.employees.append(employee) self.conn_to_employee_dict[conn] = employee # Register employee communication - for i, employee in enumerate(self.employees): - w_id = self.lower_id_bound + i + for employee in self.employees: + w_id = employee.id assert employee.conn.recv() == (RuntimeMessage.STARTED, w_id) self.sel.register( employee.conn, selectors.EVENT_READ, MessageDirection.BELOW, ) - self.logger.info(f'Registered worker {i}.') + _logger.info(f'Registered worker {w_id}.') self.step_size = 1 self.total_workers = num_workers self.num_idle_workers = num_workers - self.logger.info(f'Node has connected to {num_workers} workers.') + _logger.info(f'Node has connected to {num_workers} workers.') def listen_once(self, ip: str, port: int) -> Connection: """Listen on `ip`:`port` for a connection and return on first one.""" @@ -340,9 +395,42 @@ def listen_once(self, ip: str, port: int) -> Connection: listener.close() return conn + def send_outgoing(self) -> None: + """Outgoing thread forwards messages as they are created.""" + while True: + outgoing = self.outgoing.get() + + if not self.running: + # NodeBase's handle_shutdown will put a dummy value in the + # queue to wake the thread up so it can exit safely. + # Hence the node.running check now rather than in the + # while condition. 
+ break + + if outgoing[0].closed: + continue + + try: + outgoing[0].send((outgoing[1], outgoing[2])) + except (EOFError, ConnectionResetError): + self.handle_disconnect(outgoing[0]) + _logger.warning('Connection reset while sending message.') + continue + + if _logger.isEnabledFor(logging.DEBUG): + to = self.get_to_string(outgoing[0]) + _logger.debug(f'Sent message {outgoing[1].name} to {to}.') + + if outgoing[1] == RuntimeMessage.SUBMIT_BATCH: + _logger.log(1, f'[{outgoing[2][0]}] * {len(outgoing[2])}\n') + else: + _logger.log(1, f'{outgoing[2]}\n') + + self.outgoing.task_done() + def run(self) -> None: """Main loop.""" - self.logger.info(f'{self.__class__.__name__} running...') + _logger.info(f'{self.__class__.__name__} running...') try: while self.running: @@ -356,7 +444,7 @@ def run(self) -> None: # If interrupted by signal, shutdown and exit if direction == MessageDirection.SIGNAL: - self.logger.debug('Received interrupt signal.') + _logger.debug('Received interrupt signal.') self.handle_shutdown() return @@ -367,8 +455,11 @@ def run(self) -> None: self.handle_disconnect(conn) continue log = f'Received message {msg.name} from {direction.name}.' - self.logger.debug(log) - self.logger.log(1, f'{payload}\n') + _logger.debug(log) + if msg == RuntimeMessage.SUBMIT_BATCH: + _logger.log(1, f'[{payload[0]}] * {len(payload)}\n') + else: + _logger.log(1, f'{payload}\n') # Handle message self.handle_message(msg, direction, conn, payload) @@ -376,7 +467,7 @@ def run(self) -> None: except Exception: exc_info = sys.exc_info() error_str = ''.join(traceback.format_exception(*exc_info)) - self.logger.error(error_str) + _logger.error(error_str) self.handle_system_error(error_str) finally: @@ -412,27 +503,35 @@ def handle_system_error(self, error_str: str) -> None: RuntimeTask's coroutine code. 
""" + @abc.abstractmethod + def get_to_string(self, conn: Connection) -> str: + """Return a string representation of the connection.""" + def handle_shutdown(self) -> None: """Shutdown the node and release resources.""" # Stop running - self.logger.info('Shutting down node.') + _logger.info('Shutting down node.') self.running = False # Instruct employees to shutdown for employee in self.employees: - employee.shutdown() + employee.initiate_shutdown() + + for employee in self.employees: + employee.complete_shutdown() + self.employees.clear() - self.logger.debug('Shutdown employees.') + _logger.debug('Shutdown employees.') # Close selector self.sel.close() - self.logger.debug('Cleared selector.') + _logger.debug('Cleared selector.') # Close outgoing thread if self.outgoing_thread.is_alive(): self.outgoing.put(b'\0') # type: ignore self.outgoing_thread.join() - self.logger.debug('Joined outgoing thread.') + _logger.debug('Joined outgoing thread.') assert not self.outgoing_thread.is_alive() def handle_disconnect(self, conn: Connection) -> None: @@ -444,10 +543,6 @@ def handle_disconnect(self, conn: Connection) -> None: if conn in self.conn_to_employee_dict: self.handle_shutdown() - def __del__(self) -> None: - """Ensure resources are cleaned up.""" - self.handle_shutdown() - def assign_tasks( self, tasks: Sequence[RuntimeTask], @@ -513,10 +608,13 @@ def schedule_tasks(self, tasks: Sequence[RuntimeTask]) -> None: """Schedule tasks between this node's employees.""" if len(tasks) == 0: return - - assignments = self.assign_tasks(tasks) - - for e, assignment in zip(self.employees, assignments): + assignments = zip(self.employees, self.assign_tasks(tasks)) + sorted_assignments = sorted( + assignments, + key=lambda x: x[0].num_idle_workers, + reverse=True, + ) # Employees with the most idle workers get assignments first + for e, assignment in sorted_assignments: num_tasks = len(assignment) if num_tasks == 0: @@ -526,6 +624,7 @@ def schedule_tasks(self, tasks: 
Sequence[RuntimeTask]) -> None: e.num_tasks += num_tasks e.num_idle_workers -= min(num_tasks, e.num_idle_workers) + e.submit_cache.append((assignment[0].unique_id, num_tasks)) self.num_idle_workers = sum(e.num_idle_workers for e in self.employees) @@ -549,28 +648,41 @@ def get_employee_responsible_for(self, worker_id: int) -> RuntimeEmployee: employee_id = (worker_id - self.lower_id_bound) // self.step_size return self.employees[employee_id] - def broadcast_cancel(self, addr: RuntimeAddress) -> None: + def broadcast(self, msg: RuntimeMessage, payload: Any) -> None: """Broadcast a cancel message to my employees.""" for employee in self.employees: - self.outgoing.put((employee.conn, RuntimeMessage.CANCEL, addr)) + self.outgoing.put((employee.conn, msg, payload)) + + def handle_importpath(self, paths: list[str]) -> None: + """Update the system path with the given paths.""" + for path in paths: + if path not in sys.path: + sys.path.append(path) + self.broadcast(RuntimeMessage.IMPORTPATH, paths) - def handle_waiting(self, conn: Connection, new_idle_count: int) -> None: + def handle_waiting( + self, + conn: Connection, + new_idle_count: int, + read_receipt: RuntimeAddress | None, + ) -> None: """ Record that an employee is idle with nothing to do. - There is a race condition here that is allowed. If an employee - sends a waiting message at the same time that this sends it a - task, it will still be marked waiting even though it is running - a task. We allow this for two reasons. First, the consequences are - minimal: this situation can only lead to one extra task assigned - to the worker that could otherwise go to a truly idle worker. - Second, it is unlikely in the common BQSKit workflows, which have - wide and shallow task graphs and each leaf task can require seconds - of runtime. + There is a race condition that is corrected here. 
If an employee sends a + waiting message at the same time that its boss sends it a task, the + boss's idle count will eventually be incorrect. To fix this, every + waiting message sent by an employee is accompanied by a read receipt of + the latest batch of tasks it has processed. The boss can then adjust the + idle count by the number of tasks sent since the read receipt. """ - old_count = self.conn_to_employee_dict[conn].num_idle_workers - self.conn_to_employee_dict[conn].num_idle_workers = new_idle_count - self.num_idle_workers += (new_idle_count - old_count) + employee = self.conn_to_employee_dict[conn] + unaccounted_task = employee.get_num_of_tasks_sent_since(read_receipt) + adjusted_idle_count = max(new_idle_count - unaccounted_task, 0) + + old_count = employee.num_idle_workers + employee.num_idle_workers = adjusted_idle_count + self.num_idle_workers += (adjusted_idle_count - old_count) assert 0 <= self.num_idle_workers <= self.total_workers diff --git a/bqskit/runtime/detached.py b/bqskit/runtime/detached.py index 3b817ef1a..ea32afbd6 100644 --- a/bqskit/runtime/detached.py +++ b/bqskit/runtime/detached.py @@ -8,17 +8,16 @@ import time import uuid from dataclasses import dataclass -from logging import LogRecord from multiprocessing.connection import Connection from multiprocessing.connection import Listener from threading import Thread from typing import Any from typing import cast from typing import List +from typing import Optional from typing import Sequence +from typing import Tuple -from bqskit.compiler.status import CompilationStatus -from bqskit.compiler.task import CompilationTask from bqskit.runtime import default_server_port from bqskit.runtime.address import RuntimeAddress from bqskit.runtime.base import import_tests_package @@ -30,25 +29,7 @@ from bqskit.runtime.task import RuntimeTask -def listen(server: DetachedServer, port: int) -> None: - """Listening thread listens for client connections.""" - listener = Listener(('0.0.0.0', port)) - while 
server.running: - client = listener.accept() - - if server.running: - # We check again that the server is running before registering - # the client because dummy data is sent to unblock - # listener.accept() during server shutdown - server.clients[client] = set() - server.sel.register( - client, - selectors.EVENT_READ, - MessageDirection.CLIENT, - ) - server.logger.debug('Connected and registered new client.') - - listener.close() +_logger = logging.getLogger(__name__) @dataclass @@ -115,9 +96,30 @@ def __init__( # Start client listener self.port = port - self.listen_thread = Thread(target=listen, args=(self, port)) + self.listen_thread = Thread(target=self.listen, args=(port,)) + self.listen_thread.daemon = True self.listen_thread.start() - self.logger.info(f'Started client listener on port {self.port}.') + _logger.info(f'Started client listener on port {self.port}.') + + def listen(self, port: int) -> None: + """Listening thread listens for client connections.""" + listener = Listener(('0.0.0.0', port)) + while self.running: + client = listener.accept() + + if self.running: + # We check again that the server is running before registering + # the client because dummy data is sent to unblock + # listener.accept() during server shutdown + self.clients[client] = set() + self.sel.register( + client, + selectors.EVENT_READ, + MessageDirection.CLIENT, + ) + _logger.debug('Connected and registered new client.') + + listener.close() def handle_message( self, @@ -130,14 +132,14 @@ def handle_message( if direction == MessageDirection.CLIENT: if msg == RuntimeMessage.CONNECT: - pass + paths = cast(List[str], payload) + self.handle_connect(conn, paths) elif msg == RuntimeMessage.DISCONNECT: self.handle_disconnect(conn) elif msg == RuntimeMessage.SUBMIT: - ctask = cast(CompilationTask, payload) - self.handle_new_comp_task(conn, ctask) + self.handle_new_comp_task(conn, payload) elif msg == RuntimeMessage.REQUEST: request = cast(uuid.UUID, payload) @@ -176,25 +178,34 @@ def 
handle_message( self.handle_log(payload) elif msg == RuntimeMessage.CANCEL: - self.broadcast_cancel(payload) + self.broadcast(msg, payload) elif msg == RuntimeMessage.SHUTDOWN: self.handle_shutdown() elif msg == RuntimeMessage.WAITING: - num_idle = cast(int, payload) - self.handle_waiting(conn, num_idle) + p = cast(Tuple[int, Optional[RuntimeAddress]], payload) + num_idle, read_receipt = p + self.handle_waiting(conn, num_idle, read_receipt) elif msg == RuntimeMessage.UPDATE: task_diff = cast(int, payload) self.conn_to_employee_dict[conn].num_tasks += task_diff + elif msg == RuntimeMessage.COMMUNICATE: + self.broadcast(msg, payload) + else: raise RuntimeError(f'Unexpected message type: {msg.name}') else: raise RuntimeError(f'Unexpected message from {direction.name}.') + def handle_connect(self, conn: Connection, paths: list[str]) -> None: + """Handle a client connection request.""" + self.handle_importpath(paths) + self.outgoing.put((conn, RuntimeMessage.READY, None)) + def handle_system_error(self, error_str: str) -> None: """ Handle an error in runtime code as opposed to client code. 
@@ -208,6 +219,13 @@ def handle_system_error(self, error_str: str) -> None: # Sleep to ensure clients receive error message before shutdown time.sleep(1) + def get_to_string(self, conn: Connection) -> str: + """Return a string representation of the connection.""" + if conn in self.clients: + return 'CLIENT' + + return self.conn_to_employee_dict[conn].recipient_string + def handle_shutdown(self) -> None: """Shutdown the runtime.""" super().handle_shutdown() @@ -219,7 +237,7 @@ def handle_shutdown(self) -> None: except Exception: pass self.clients.clear() - self.logger.debug('Cleared clients.') + _logger.debug('Cleared clients.') # Close listener (hasattr checked for attachedserver shutdown) if hasattr(self, 'listen_thread') and self.listen_thread.is_alive(): @@ -231,27 +249,39 @@ def handle_shutdown(self) -> None: dummy_socket.connect(('localhost', self.port)) dummy_socket.close() self.listen_thread.join() - self.logger.debug('Joined listening thread.') + _logger.debug('Joined listening thread.') def handle_disconnect(self, conn: Connection) -> None: """Disconnect a client connection from the runtime.""" super().handle_disconnect(conn) tasks = self.clients.pop(conn) + for task_id in tasks: self.handle_cancel_comp_task(task_id) - self.logger.info('Unregistered client.') + + tasks_to_pop = [] + for (task, (tid, other_conn)) in self.tasks.items(): + if other_conn == conn: + tasks_to_pop.append((task_id, tid)) + + for task_id, tid in tasks_to_pop: + self.tasks.pop(task_id) + self.mailbox_to_task_dict.pop(tid) + + _logger.info('Unregistered client.') def handle_new_comp_task( self, conn: Connection, - task: CompilationTask, + task: Any, # Explicitly not CompilationTask to avoid early import ) -> None: """Convert a :class:`CompilationTask` into an internal one.""" + from bqskit.compiler.task import CompilationTask mailbox_id = self._get_new_mailbox_id() self.tasks[task.task_id] = (mailbox_id, conn) self.mailbox_to_task_dict[mailbox_id] = task.task_id 
self.mailboxes[mailbox_id] = ServerMailbox() - self.logger.info(f'New CompilationTask: {task.task_id}.') + _logger.info(f'New CompilationTask: {task.task_id}.') self.clients[conn].add(task.task_id) @@ -279,7 +309,7 @@ def handle_request(self, conn: Connection, request: uuid.UUID) -> None: if box.ready: # If the result has already arrived, ship it to the client. - self.logger.info(f'Responding to request for task {request}.') + _logger.info(f'Responding to request for task {request}.') self.outgoing.put((conn, RuntimeMessage.RESULT, box.result)) self.mailboxes.pop(mailbox_id) self.clients[conn].remove(request) @@ -294,6 +324,7 @@ def handle_request(self, conn: Connection, request: uuid.UUID) -> None: def handle_status(self, conn: Connection, request: uuid.UUID) -> None: """Inform the client if the task is finished or not.""" + from bqskit.compiler.status import CompilationStatus if request not in self.clients[conn] or request not in self.tasks: # This task is unknown to the system m = (conn, RuntimeMessage.STATUS, CompilationStatus.UNKNOWN) @@ -309,7 +340,7 @@ def handle_status(self, conn: Connection, request: uuid.UUID) -> None: def handle_cancel_comp_task(self, request: uuid.UUID) -> None: """Cancel a compilation task in the system.""" - self.logger.info(f'Cancelling: {request}.') + _logger.info(f'Cancelling: {request}.') # Remove task from server data mailbox_id, client_conn = self.tasks[request] @@ -319,7 +350,7 @@ def handle_cancel_comp_task(self, request: uuid.UUID) -> None: # Forward internal cancel messages addr = RuntimeAddress(-1, mailbox_id, 0) - self.broadcast_cancel(addr) + self.broadcast(RuntimeMessage.CANCEL, addr) # Acknowledge the client's cancel request if not client_conn.closed: @@ -340,10 +371,10 @@ def handle_result(self, result: RuntimeResult) -> None: box = self.mailboxes[mailbox_id] box.result = result.result t_id = self.mailbox_to_task_dict[mailbox_id] - self.logger.info(f'Finished: {t_id}.') + _logger.info(f'Finished: {t_id}.') if 
box.client_waiting: - self.logger.info(f'Responding to request for task {t_id}.') + _logger.info(f'Responding to request for task {t_id}.') m = (self.tasks[t_id][1], RuntimeMessage.RESULT, box.result) self.outgoing.put(m) self.clients[self.tasks[t_id][1]].remove(t_id) @@ -365,12 +396,24 @@ def handle_error(self, error_payload: tuple[int, str]) -> None: raise RuntimeError(error_payload) tid = error_payload[0] + if tid not in self.mailbox_to_task_dict: + return # Silently discard errors from cancelled tasks + conn = self.tasks[self.mailbox_to_task_dict[tid]][1] self.outgoing.put((conn, RuntimeMessage.ERROR, error_payload[1])) - - def handle_log(self, log_payload: tuple[int, LogRecord]) -> None: + # TODO: Broadcast cancel to all tasks with compilation task id tid + # But avoid double broadcasting it. If the client crashes due to + # this error, which it may not, then we will quickly process + # a handle_disconnect and call the cancel anyways. We should + # still cancel here incase the client catches the error and + # resubmits a job. 
+ + def handle_log(self, log_payload: tuple[int, bytes]) -> None: """Forward logs to appropriate client.""" tid = log_payload[0] + if tid not in self.mailbox_to_task_dict: + return # Silently discard logs from cancelled tasks + conn = self.tasks[self.mailbox_to_task_dict[tid]][1] self.outgoing.put((conn, RuntimeMessage.LOG, log_payload[1])) @@ -413,9 +456,16 @@ def start_server() -> None: ipports = parse_ipports(args.managers) # Set up logging - _logger = logging.getLogger('bqskit-runtime') - _logger.setLevel([30, 20, 10, 1][min(args.verbose, 3)]) - _logger.addHandler(logging.StreamHandler()) + log_level = [30, 20, 10, 1][min(args.verbose, 3)] + logging.getLogger().setLevel(log_level) + _handler = logging.StreamHandler() + _handler.setLevel(0) + _fmt_header = '%(asctime)s.%(msecs)03d - %(levelname)-8s |' + _fmt_message = ' %(module)s: %(message)s' + _fmt = _fmt_header + _fmt_message + _formatter = logging.Formatter(_fmt, '%H:%M:%S') + _handler.setFormatter(_formatter) + logging.getLogger().addHandler(_handler) # Import tests package recursively if args.import_tests: diff --git a/bqskit/runtime/future.py b/bqskit/runtime/future.py index 70f6ac2cc..ab69d5911 100644 --- a/bqskit/runtime/future.py +++ b/bqskit/runtime/future.py @@ -27,8 +27,8 @@ def __await__(self) -> Any: Informs the event loop which mailbox this is waiting on. 
""" - if self._next_flag: - return (yield self) + # if self._next_flag: + # return (yield self) return (yield self) diff --git a/bqskit/runtime/manager.py b/bqskit/runtime/manager.py index 7779f47ca..14827af9e 100644 --- a/bqskit/runtime/manager.py +++ b/bqskit/runtime/manager.py @@ -9,7 +9,9 @@ from typing import Any from typing import cast from typing import List +from typing import Optional from typing import Sequence +from typing import Tuple from bqskit.runtime import default_manager_port from bqskit.runtime import default_worker_port @@ -23,6 +25,9 @@ from bqskit.runtime.task import RuntimeTask +_logger = logging.getLogger(__name__) + + class Manager(ServerBase): """ BQSKit Runtime Manager. @@ -42,6 +47,8 @@ def __init__( ipports: list[tuple[str, int]] | None = None, worker_port: int = default_worker_port, only_connect: bool = False, + log_level: int = logging.WARNING, + num_blas_threads: int = 1, ) -> None: """ Create a manager instance in one of two ways: @@ -78,6 +85,16 @@ def __init__( only_connect (bool): If true, do not spawn workers, only connect to already spawned workers. + + log_level (int): The logging level for the manager and workers. + If `only_connect` is True, doesn't set worker's log level. + In that case, set the worker's log level when spawning them. + (Default: logging.WARNING). + + num_blas_threads (int): The number of threads to use in BLAS + libraries. If `only_connect` is True this is ignored. In + that case, set the thread count when spawning workers. + (Default: 1). 
""" super().__init__() @@ -95,24 +112,32 @@ def __init__( MessageDirection.ABOVE, ) - # Case 1: spawn and manage workers + # Case 1: spawn and/or manage workers if ipports is None: if only_connect: self.connect_to_workers(num_workers, worker_port) else: - self.spawn_workers(num_workers, worker_port) - - # Case 2: Connect to managers at ipports + self.spawn_workers( + num_workers, + worker_port, + log_level, + num_blas_threads, + ) + + # Case 2: Connect to detached managers at ipports else: self.connect_to_managers(ipports) # Track info on sent messages to reduce redundant messages: self.last_num_idle_sent_up = self.total_workers + # Track info on received messages to report read receipts: + self.most_recent_read_submit: RuntimeAddress | None = None + # Inform upstream we are starting msg = (self.upstream, RuntimeMessage.STARTED, self.total_workers) self.outgoing.put(msg) - self.logger.info('Sent start message upstream.') + _logger.info('Sent start message upstream.') def handle_message( self, @@ -126,25 +151,33 @@ def handle_message( if msg == RuntimeMessage.SUBMIT: rtask = cast(RuntimeTask, payload) + self.most_recent_read_submit = rtask.unique_id self.schedule_tasks([rtask]) - self.update_upstream_idle_workers() + # self.update_upstream_idle_workers() elif msg == RuntimeMessage.SUBMIT_BATCH: rtasks = cast(List[RuntimeTask], payload) + self.most_recent_read_submit = rtasks[0].unique_id self.schedule_tasks(rtasks) - self.update_upstream_idle_workers() + # self.update_upstream_idle_workers() elif msg == RuntimeMessage.RESULT: result = cast(RuntimeResult, payload) self.send_result_down(result) elif msg == RuntimeMessage.CANCEL: - addr = cast(RuntimeAddress, payload) - self.broadcast_cancel(addr) + self.broadcast(RuntimeMessage.CANCEL, payload) elif msg == RuntimeMessage.SHUTDOWN: self.handle_shutdown() + elif msg == RuntimeMessage.IMPORTPATH: + paths = cast(List[str], payload) + self.handle_importpath(paths) + + elif msg == RuntimeMessage.COMMUNICATE: + 
self.broadcast(RuntimeMessage.COMMUNICATE, payload) + else: raise RuntimeError(f'Unexpected message type: {msg.name}') @@ -153,20 +186,19 @@ def handle_message( if msg == RuntimeMessage.SUBMIT: rtask = cast(RuntimeTask, payload) self.send_up_or_schedule_tasks([rtask]) - self.update_upstream_idle_workers() elif msg == RuntimeMessage.SUBMIT_BATCH: rtasks = cast(List[RuntimeTask], payload) self.send_up_or_schedule_tasks(rtasks) - self.update_upstream_idle_workers() elif msg == RuntimeMessage.RESULT: result = cast(RuntimeResult, payload) self.handle_result_from_below(result) elif msg == RuntimeMessage.WAITING: - num_idle = cast(int, payload) - self.handle_waiting(conn, num_idle) + p = cast(Tuple[int, Optional[RuntimeAddress]], payload) + num_idle, read_receipt = p + self.handle_waiting(conn, num_idle, read_receipt) self.update_upstream_idle_workers() elif msg == RuntimeMessage.UPDATE: @@ -197,6 +229,13 @@ def handle_system_error(self, error_str: str) -> None: # If server has crashed then just exit pass + def get_to_string(self, conn: Connection) -> str: + """Return a string representation of the connection.""" + if conn == self.upstream: + return 'BOSS' + + return self.conn_to_employee_dict[conn].recipient_string + def handle_shutdown(self) -> None: """Shutdown the manager and clean up spawned processes.""" super().handle_shutdown() @@ -217,6 +256,7 @@ def send_up_or_schedule_tasks(self, tasks: Sequence[RuntimeTask]) -> None: if num_idle != 0: self.outgoing.put((self.upstream, RuntimeMessage.UPDATE, num_idle)) self.schedule_tasks(tasks[:num_idle]) + self.update_upstream_idle_workers() if len(tasks) > num_idle: self.outgoing.put(( @@ -244,7 +284,8 @@ def update_upstream_idle_workers(self) -> None: """Update the total number of idle workers upstream.""" if self.num_idle_workers != self.last_num_idle_sent_up: self.last_num_idle_sent_up = self.num_idle_workers - m = (self.upstream, RuntimeMessage.WAITING, self.num_idle_workers) + payload = (self.num_idle_workers, 
self.most_recent_read_submit) + m = (self.upstream, RuntimeMessage.WAITING, payload) self.outgoing.put(m) def handle_update(self, conn: Connection, task_diff: int) -> None: @@ -305,9 +346,16 @@ def start_manager() -> None: ipports = None if args.managers is None else parse_ipports(args.managers) # Set up logging - _logger = logging.getLogger('bqskit-runtime') - _logger.setLevel([30, 20, 10, 1][min(args.verbose, 3)]) - _logger.addHandler(logging.StreamHandler()) + log_level = [30, 20, 10, 1][min(args.verbose, 3)] + logging.getLogger().setLevel(log_level) + _handler = logging.StreamHandler() + _handler.setLevel(0) + _fmt_header = '%(asctime)s.%(msecs)03d - %(levelname)-8s |' + _fmt_message = ' %(module)s: %(message)s' + _fmt = _fmt_header + _fmt_message + _formatter = logging.Formatter(_fmt, '%H:%M:%S') + _handler.setFormatter(_formatter) + logging.getLogger().addHandler(_handler) # Import tests package recursively if args.import_tests: diff --git a/bqskit/runtime/message.py b/bqskit/runtime/message.py index 63f687048..d2585aef2 100644 --- a/bqskit/runtime/message.py +++ b/bqskit/runtime/message.py @@ -20,3 +20,6 @@ class RuntimeMessage(IntEnum): CANCEL = 11 WAITING = 12 UPDATE = 13 + IMPORTPATH = 14 + READY = 15 + COMMUNICATE = 16 diff --git a/bqskit/runtime/task.py b/bqskit/runtime/task.py index b55037a87..d8cef7855 100644 --- a/bqskit/runtime/task.py +++ b/bqskit/runtime/task.py @@ -6,6 +6,8 @@ from typing import Any from typing import Coroutine +import dill + from bqskit.runtime.address import RuntimeAddress @@ -34,18 +36,26 @@ def __init__( breadcrumbs: tuple[RuntimeAddress, ...], logging_level: int | None = None, max_logging_depth: int = -1, + task_name: str | None = None, + log_context: dict[str, str] = {}, ) -> None: """Create the task with a new id and return address.""" RuntimeTask.task_counter += 1 self.task_id = RuntimeTask.task_counter - self.fnargs = fnargs + self.serialized_fnargs = dill.dumps(fnargs) + self._fnargs: tuple[Any, Any, Any] | None = None 
+ self._name = fnargs[0].__name__ if task_name is None else task_name """Tuple of function pointer, arguments, and keyword arguments.""" self.return_address = return_address - """Where the result of this task should be sent.""" + """ + Where the result of this task should be sent. + + This doubles as a unique system-wide id for the task. + """ - self.logging_level = logging_level + self.logging_level = logging_level or 0 """Logs with levels >= to this get emitted, if None always emit.""" self.comp_task_id = comp_task_id @@ -60,9 +70,6 @@ def __init__( self.coro: Coroutine[Any, Any, Any] | None = None """The coroutine containing this tasks code.""" - # self.send: Any = None - # """A register that both the coroutine and task have access to.""" - self.desired_box_id: int | None = None """When waiting on a mailbox, this stores that mailbox's id.""" @@ -72,6 +79,19 @@ def __init__( self.wake_on_next: bool = False """Set to true if this task should wake immediately on a result.""" + self.log_context: dict[str, str] = log_context + """Additional context to be logged with this task.""" + + self.msg_buffer: list[Any] = [] + + @property + def fnargs(self) -> tuple[Any, Any, Any]: + """Return the function pointer, arguments, and keyword arguments.""" + if self._fnargs is None: + self._fnargs = dill.loads(self.serialized_fnargs) + assert self._fnargs is not None # for type checker + return self._fnargs + def step(self, send_val: Any = None) -> Any: """Execute one step of the task.""" if self.coro is None: @@ -87,7 +107,9 @@ def step(self, send_val: Any = None) -> Any: self.max_logging_depth < 0 or len(self.breadcrumbs) <= self.max_logging_depth ): - logging.getLogger().setLevel(0) + logging.getLogger().setLevel(self.logging_level) + else: + logging.getLogger().setLevel(100) # Execute a task step to_return = self.coro.send(send_val) @@ -97,10 +119,30 @@ def step(self, send_val: Any = None) -> Any: return to_return + @property + def unique_id(self) -> RuntimeAddress: + """Return 
the task's system-wide unique id.""" + return self.return_address + def start(self) -> None: """Initialize the task.""" self.coro = self.run() + def cancel(self) -> None: + """Ask the coroutine to gracefully exit.""" + if self.coro is not None: + # If this call to "close" raises a RuntimeError, + # it is likely a blanket try/accept catching the + # error used to stop the coroutine, preventing + # it from stopping correctly. + try: + self.coro.close() + except ValueError: + # Coroutine is running and cannot be closed. + pass + else: + raise RuntimeError('Task was cancelled with None coroutine.') + async def run(self) -> Any: """Task coroutine wrapper.""" if inspect.iscoroutinefunction(self.fnargs[0]): @@ -110,3 +152,11 @@ async def run(self) -> Any: def is_descendant_of(self, addr: RuntimeAddress) -> bool: """Return true if `addr` identifies a parent (or this) task.""" return addr == self.return_address or addr in self.breadcrumbs + + def __str__(self) -> str: + """Return a string representation of the task.""" + return f'{self._name}' + + def __repr__(self) -> str: + """Return a string representation of the task.""" + return f'' diff --git a/bqskit/runtime/worker.py b/bqskit/runtime/worker.py index 5c4b3ca18..e61b13009 100644 --- a/bqskit/runtime/worker.py +++ b/bqskit/runtime/worker.py @@ -4,22 +4,27 @@ import argparse import logging import os +import pickle import signal import sys import time import traceback -from collections import OrderedDict from dataclasses import dataclass from multiprocessing import Process from multiprocessing.connection import Client from multiprocessing.connection import Connection -from multiprocessing.connection import wait +from queue import Empty +from queue import Queue +from threading import Lock +from threading import Thread from typing import Any from typing import Callable from typing import cast from typing import List +from typing import Sequence from bqskit.runtime import default_worker_port +from bqskit.runtime import 
set_blas_thread_counts from bqskit.runtime.address import RuntimeAddress from bqskit.runtime.future import RuntimeFuture from bqskit.runtime.message import RuntimeMessage @@ -27,31 +32,7 @@ from bqskit.runtime.task import RuntimeTask -class WorkerQueue(): - """The worker's task FIFO queue.""" - - def __init__(self) -> None: - """ - Initialize the worker queue. - - An OrderedDict is used to internally store the task. This prevents the - same task appearing multiple times in the queue, while also ensuring - O(1) operations. - """ - self._queue: OrderedDict[RuntimeAddress, None] = OrderedDict() - - def put(self, addr: RuntimeAddress) -> None: - """Enqueue a task by its address.""" - if addr not in self._queue: - self._queue[addr] = None - - def get(self) -> RuntimeAddress: - """Get the next task to run.""" - return self._queue.popitem(last=False)[0] - - def empty(self) -> bool: - """Check if the queue is empty.""" - return len(self._queue) == 0 +_logger = logging.getLogger(__name__) @dataclass @@ -126,11 +107,13 @@ class Worker: """ BQSKit Runtime's Worker. - BQSKit Runtime utilizes a single-threaded worker to accept, execute, + BQSKit Runtime utilizes a dual-threaded worker to accept, execute, pause, spawn, resume, and complete tasks in a custom event loop built with python's async await mechanisms. Each worker receives and sends tasks and results to the greater system through a single duplex - connection with a runtime server or manager. + connection with a runtime server or manager. One thread performs + work and sends outgoing messages, while the other thread handles + incoming messages. At start-up, the worker receives an ID and waits for its first task. 
An executing task may use the `submit` and `map` methods to spawn child @@ -178,16 +161,20 @@ def __init__(self, id: int, conn: Connection) -> None: self._id = id self._conn = conn - self._outgoing: list[tuple[RuntimeMessage, Any]] = [] - """Stores outgoing messages to be handled by the event loop.""" - self._tasks: dict[RuntimeAddress, RuntimeTask] = {} """Tracks all started, unfinished tasks on this worker.""" self._delayed_tasks: list[RuntimeTask] = [] - """Store all delayed tasks in LIFO order.""" + """ + Store all delayed tasks in LIFO order. + + Delayed tasks have no context and are stored (more-or-less) as a + function pointer together with the arguments. When it gets started, it + consumes much more memory, so we delay the task start until necessary + (at no cost) + """ - self._ready_task_ids: WorkerQueue = WorkerQueue() + self._ready_task_ids: Queue[RuntimeAddress] = Queue() """Tasks queued up for execution.""" self._cancelled_task_ids: set[RuntimeAddress] = set() @@ -196,7 +183,7 @@ def __init__(self, id: int, conn: Connection) -> None: self._active_task: RuntimeTask | None = None """The currently executing task if one is running.""" - self._running = False + self._running = True """Controls if the event loop is running.""" self._mailboxes: dict[int, WorkerMailbox] = {} @@ -208,71 +195,112 @@ def __init__(self, id: int, conn: Connection) -> None: self._cache: dict[str, Any] = {} """Local worker cache.""" - # Send out every emitted log message upstream + self.most_recent_read_submit: RuntimeAddress | None = None + """Tracks the most recently processed submit message from above.""" + + self.read_receipt_mutex = Lock() + """ + A lock to ensure waiting messages's read receipt is correct. + + This lock enforces atomic update of `most_recent_read_submit` and + task addition/enqueueing. This is necessary to ensure that the + idle status is always correct. 
+ """ + + # Send out every client emitted log message upstream old_factory = logging.getLogRecordFactory() def record_factory(*args: Any, **kwargs: Any) -> logging.LogRecord: record = old_factory(*args, **kwargs) - active_task = get_worker()._active_task - if active_task is not None: - lvl = active_task.logging_level - if lvl is None or lvl <= record.levelno: - tid = active_task.comp_task_id - self._outgoing.append((RuntimeMessage.LOG, (tid, record))) + active_task = self._active_task + if not record.name.startswith('bqskit.runtime'): + if active_task is not None: + lvl = active_task.logging_level + if lvl is None or lvl <= record.levelno: + if lvl <= logging.DEBUG: + record.msg += f' [wid={self._id}' + items = active_task.log_context.items() + if len(items) > 0: + record.msg += ', ' + con_str = ', '.join(f'{k}={v}' for k, v in items) + record.msg += con_str + record.msg += ']' + tid = active_task.comp_task_id + try: + serial = pickle.dumps(record) + except (pickle.PicklingError, TypeError): + serial = pickle.dumps(( + record.name, + record.levelno, + record.getMessage(), + )) + self._conn.send((RuntimeMessage.LOG, (tid, serial))) return record logging.setLogRecordFactory(record_factory) + # Start incoming thread + self.incoming_thread = Thread(target=self.recv_incoming) + self.incoming_thread.daemon = True + self.incoming_thread.start() + _logger.debug('Started incoming thread.') + # Communicate that this worker is ready self._conn.send((RuntimeMessage.STARTED, self._id)) def _loop(self) -> None: """Main worker event loop.""" - self._running = True while self._running: - self._try_step_next_ready_task() - self._try_idle() - self._handle_comms() - - def _try_idle(self) -> None: - """If there is nothing to do, wait until we receive a message.""" - empty_outgoing = len(self._outgoing) == 0 - no_ready_tasks = self._ready_task_ids.empty() - no_delayed_tasks = len(self._delayed_tasks) == 0 - - if empty_outgoing and no_ready_tasks and no_delayed_tasks: - 
self._conn.send((RuntimeMessage.WAITING, 1)) - wait([self._conn]) - - def _handle_comms(self) -> None: - """Handle all incoming and outgoing messages.""" - - # Handle outgoing communication - for out_msg in self._outgoing: - self._conn.send(out_msg) - self._outgoing.clear() - - # Handle incomming communication - while self._conn.poll(): - msg, payload = self._conn.recv() + try: + self._try_step_next_ready_task() + except Exception: + self._running = False + exc_info = sys.exc_info() + error_str = ''.join(traceback.format_exception(*exc_info)) + _logger.error(error_str) + try: + self._conn.send((RuntimeMessage.ERROR, error_str)) + except Exception: + pass + + def recv_incoming(self) -> None: + """Continuously receive all incoming messages.""" + while self._running: + # Receive message + try: + msg, payload = self._conn.recv() + except Exception: + _logger.debug('Crashed due to lost connection') + if sys.platform == 'win32': + os.kill(os.getpid(), 9) + else: + os.kill(os.getpid(), signal.SIGKILL) + exit() + + _logger.debug(f'Received message {msg.name}.') + _logger.log(1, f'Payload: {payload}') # Process message if msg == RuntimeMessage.SHUTDOWN: - self._running = False - return + if sys.platform == 'win32': + os.kill(os.getpid(), 9) + else: + os.kill(os.getpid(), signal.SIGKILL) elif msg == RuntimeMessage.SUBMIT: + self.read_receipt_mutex.acquire() task = cast(RuntimeTask, payload) + self.most_recent_read_submit = task.unique_id self._add_task(task) + self.read_receipt_mutex.release() elif msg == RuntimeMessage.SUBMIT_BATCH: + self.read_receipt_mutex.acquire() tasks = cast(List[RuntimeTask], payload) + self.most_recent_read_submit = tasks[0].unique_id self._add_task(tasks.pop()) # Submit one task self._delayed_tasks.extend(tasks) # Delay rest - # Delayed tasks have no context and are stored (more-or-less) - # as a function pointer together with the arguments. 
- # When it gets started, it consumes much more memory, - # so we delay the task start until necessary (at no cost) + self.read_receipt_mutex.release() elif msg == RuntimeMessage.RESULT: result = cast(RuntimeResult, payload) @@ -281,6 +309,17 @@ def _handle_comms(self) -> None: elif msg == RuntimeMessage.CANCEL: addr = cast(RuntimeAddress, payload) self._handle_cancel(addr) + # TODO: preempt? + + elif msg == RuntimeMessage.COMMUNICATE: + addrs, msg = cast(tuple[list[RuntimeAddress], Any], payload) + self._handle_communicate(addrs, msg) + + elif msg == RuntimeMessage.IMPORTPATH: + paths = cast(List[str], payload) + for path in paths: + if path not in sys.path: + sys.path.append(path) def _add_task(self, task: RuntimeTask) -> None: """Start a task and add it to the loop.""" @@ -290,8 +329,9 @@ def _add_task(self, task: RuntimeTask) -> None: def _handle_result(self, result: RuntimeResult) -> None: """Insert result into appropriate mailbox and wake waiting task.""" - mailbox_id = result.return_address.mailbox_index assert result.return_address.worker_id == self._id + + mailbox_id = result.return_address.mailbox_index if mailbox_id not in self._mailboxes: # If the mailbox has been dropped due to a cancel, ignore result return @@ -304,7 +344,11 @@ def _handle_result(self, result: RuntimeResult) -> None: task = self._tasks[box.dest_addr] if task.wake_on_next or box.ready: + # print(f'Worker {self._id} is waking task + # {task.return_address}, with {task.wake_on_next=}, + # {box.ready=}') self._ready_task_ids.put(box.dest_addr) # Wake it + box.dest_addr = None # Prevent double wake def _handle_cancel(self, addr: RuntimeAddress) -> None: """ @@ -315,16 +359,21 @@ def _handle_cancel(self, addr: RuntimeAddress) -> None: to discard cancelled tasks when popping from it. Therefore, we do not do anything with `self._ready_task_ids` here. + We also must make sure to call the `cancel` function of the + tasks to make sure their coroutines are cleaned up. 
+ Also, we also don't need to send out cancel messages for cancelled children tasks since other workers can evaluate that for themselves using breadcrumbs and the original `addr` cancel message. """ + # TODO: Send update message? self._cancelled_task_ids.add(addr) # Remove all tasks that are children of `addr` from initialized tasks for key, task in self._tasks.items(): if task.is_descendant_of(addr): + task.cancel() for mailbox_id in self._tasks[key].owned_mailboxes: self._mailboxes.pop(mailbox_id) self._tasks = { @@ -338,16 +387,48 @@ def _handle_cancel(self, addr: RuntimeAddress) -> None: if not t.is_descendant_of(addr) ] + def _handle_communicate( + self, + addrs: list[RuntimeAddress], + msg: Any, + ) -> None: + for task_addr in addrs: + if task_addr not in self._tasks: + continue + + self._tasks[task_addr].msg_buffer.append(msg) + def _get_next_ready_task(self) -> RuntimeTask | None: - """Return the next ready task if one exists, otherwise None.""" + """Return the next ready task if one exists, otherwise block.""" while True: - if self._ready_task_ids.empty(): - if len(self._delayed_tasks) > 0: - self._add_task(self._delayed_tasks.pop()) - continue - return None + if self._ready_task_ids.empty() and len(self._delayed_tasks) > 0: + self._add_task(self._delayed_tasks.pop()) + continue - addr = self._ready_task_ids.get() + # Critical section + # Attempt to get a ready task. If none are available, message + # the manager/server with a waiting message letting them + # know the worker is idle. This needs to be atomic to prevent + # the self.more_recent_read_submit from being updated after + # catching the Empty exception, but before forming the payload. + self.read_receipt_mutex.acquire() + try: + addr = self._ready_task_ids.get_nowait() + + except Empty: + payload = (1, self.most_recent_read_submit) + self._conn.send((RuntimeMessage.WAITING, payload)) + self.read_receipt_mutex.release() + # Block for new message. 
Can release lock here since the + # the `self.most_recent_read_submit` has been used. + addr = self._ready_task_ids.get() + + else: + self.read_receipt_mutex.release() + + # Handle a shutdown request that occured while waiting + if not self._running: + return None if addr in self._cancelled_task_ids or addr not in self._tasks: # When a task is cancelled on the worker it is not removed @@ -362,6 +443,7 @@ def _get_next_ready_task(self) -> RuntimeTask | None: # then discard this one too. Each breadcrumb (bcb) is a # task address (unique system-wide task id) of an ancestor # task. + # TODO: do I need to manually remove addr from self._tasks? continue return task @@ -371,7 +453,6 @@ def _try_step_next_ready_task(self) -> None: task = self._get_next_ready_task() if task is None: - # Nothing to do return try: @@ -392,7 +473,7 @@ def _try_step_next_ready_task(self) -> None: exc_info = sys.exc_info() error_str = ''.join(traceback.format_exception(*exc_info)) error_payload = (self._active_task.comp_task_id, error_str) - self._outgoing.append((RuntimeMessage.ERROR, error_payload)) + self._conn.send((RuntimeMessage.ERROR, error_payload)) finally: self._active_task = None @@ -411,14 +492,17 @@ def _process_await(self, task: RuntimeTask, future: RuntimeFuture) -> None: box.dest_addr = task.return_address task.desired_box_id = future.mailbox_id - if future._next_flag: - # Set from Worker.next, implies the task wants the next result - if box.ready: - m = 'Cannot wait for next results on a complete task.' - raise RuntimeError(m) - task.wake_on_next = True - - elif box.ready: + # if future._next_flag: + # # Set from Worker.next, implies the task wants the next result + # # if box.ready: + # # m = 'Cannot wait for next results on a complete task.' 
+ # # raise RuntimeError(m) + # task.wake_on_next = True + task.wake_on_next = future._next_flag + # print(f'Worker {self._id} is waiting on task + # {task.return_address}, with {task.wake_on_next=}') + + if box.ready: self._ready_task_ids.put(task.return_address) def _process_task_completion(self, task: RuntimeTask, result: Any) -> None: @@ -426,13 +510,18 @@ def _process_task_completion(self, task: RuntimeTask, result: Any) -> None: assert task is self._active_task packaged_result = RuntimeResult(task.return_address, result, self._id) + if task.return_address not in self._tasks: + # print(f'Task was cancelled: {task.return_address}, + # {task.fnargs[0].__name__}') + return + if task.return_address.worker_id == self._id: self._handle_result(packaged_result) - self._outgoing.append((RuntimeMessage.UPDATE, -1)) + self._conn.send((RuntimeMessage.UPDATE, -1)) # Let manager know this worker has one less task # without sending a result else: - self._outgoing.append((RuntimeMessage.RESULT, packaged_result)) + self._conn.send((RuntimeMessage.RESULT, packaged_result)) # Remove task self._tasks.pop(task.return_address) @@ -448,10 +537,6 @@ def _process_task_completion(self, task: RuntimeTask, result: Any) -> None: # Otherwise send a cancel message self.cancel(RuntimeFuture(mailbox_id)) - # Start delayed task - if self._ready_task_ids.empty() and len(self._delayed_tasks) > 0: - self._add_task(self._delayed_tasks.pop()) - def _get_desired_result(self, task: RuntimeTask) -> Any: """Retrieve the task's desired result from the mailboxes.""" if task.desired_box_id is None: @@ -461,7 +546,7 @@ def _get_desired_result(self, task: RuntimeTask) -> Any: if task.wake_on_next: fresh_results = box.get_new_results() - assert len(fresh_results) > 0 + # assert len(fresh_results) > 0 return fresh_results assert box.ready @@ -478,10 +563,25 @@ def submit( self, fn: Callable[..., Any], *args: Any, + task_name: str | None = None, + log_context: dict[str, str] = {}, **kwargs: Any, ) -> 
RuntimeFuture: """Submit `fn` as a task to the runtime.""" assert self._active_task is not None + + if task_name is not None and not isinstance(task_name, str): + raise RuntimeError('task_name must be a string.') + + if not isinstance(log_context, dict): + raise RuntimeError('log_context must be a dictionary.') + + for k, v in log_context.items(): + if not isinstance(k, str) or not isinstance(v, str): + raise RuntimeError( + 'log_context must be a map from strings to strings.', + ) + # Group fnargs together fnarg = (fn, args, kwargs) @@ -495,13 +595,16 @@ def submit( fnarg, RuntimeAddress(self._id, mailbox_id, 0), self._active_task.comp_task_id, - self._active_task.breadcrumbs + (self._active_task.return_address,), + self._active_task.breadcrumbs + + (self._active_task.return_address,), self._active_task.logging_level, self._active_task.max_logging_depth, + task_name, + {**self._active_task.log_context, **log_context}, ) # Submit the task (on the next cycle) - self._outgoing.append((RuntimeMessage.SUBMIT, task)) + self._conn.send((RuntimeMessage.SUBMIT, task)) # Return future pointing to the mailbox return RuntimeFuture(mailbox_id) @@ -510,10 +613,38 @@ def map( self, fn: Callable[..., Any], *args: Any, + task_name: Sequence[str | None] | str | None = None, + log_context: Sequence[dict[str, str]] | dict[str, str] = {}, **kwargs: Any, ) -> RuntimeFuture: """Map `fn` over the input arguments distributed across the runtime.""" assert self._active_task is not None + + if task_name is None or isinstance(task_name, str): + task_name = [task_name] * len(args[0]) + + if len(task_name) != len(args[0]): + raise RuntimeError( + 'task_name must be a string or a list of strings equal' + 'in length to the number of tasks.', + ) + + if isinstance(log_context, dict): + log_context = [log_context] * len(args[0]) + + if len(log_context) != len(args[0]): + raise RuntimeError( + 'log_context must be a dictionary or a list of dictionaries' + ' equal in length to the number of tasks.', 
+ ) + + for context in log_context: + for k, v in context.items(): + if not isinstance(k, str) or not isinstance(v, str): + raise RuntimeError( + 'log_context must be a map from strings to strings.', + ) + # Group fnargs together fnargs = [] if len(args) == 1: @@ -543,16 +674,37 @@ def map( breadcrumbs, self._active_task.logging_level, self._active_task.max_logging_depth, + task_name[i], + {**self._active_task.log_context, **log_context[i]}, ) for i, fnarg in enumerate(fnargs) ] # Submit the tasks - self._outgoing.append((RuntimeMessage.SUBMIT_BATCH, tasks)) + self._conn.send((RuntimeMessage.SUBMIT_BATCH, tasks)) # Return future pointing to the mailbox return RuntimeFuture(mailbox_id) + def communicate(self, future: RuntimeFuture, msg: Any) -> None: + """Send a message to the task associated with `future`.""" + assert self._active_task is not None + assert future.mailbox_id in self._mailboxes + + num_slots = self._mailboxes[future.mailbox_id].expected_num_results + addrs = [ + RuntimeAddress(self._id, future.mailbox_id, slot_id) + for slot_id in range(num_slots) + ] + self._conn.send((RuntimeMessage.COMMUNICATE, (addrs, msg))) + + def get_messages(self) -> list[Any]: + """Return all messages received by the worker for this task.""" + assert self._active_task is not None + x = self._active_task.msg_buffer + self._active_task.msg_buffer = [] + return x + def cancel(self, future: RuntimeFuture) -> None: """Cancel all tasks associated with `future`.""" assert self._active_task is not None @@ -563,8 +715,8 @@ def cancel(self, future: RuntimeFuture) -> None: RuntimeAddress(self._id, future.mailbox_id, slot_id) for slot_id in range(num_slots) ] - msgs = [(RuntimeMessage.CANCEL, addr) for addr in addrs] - self._outgoing.extend(msgs) + for addr in addrs: + self._conn.send((RuntimeMessage.CANCEL, addr)) def get_cache(self) -> dict[str, Any]: """ @@ -591,7 +743,8 @@ async def next(self, future: RuntimeFuture) -> list[tuple[int, Any]]: returned. 
Each result is paired with the index of its arguments in the original map call. """ - if future._done: + # if future._done: + if future.mailbox_id not in self._mailboxes: raise RuntimeError('Cannot wait on an already completed result.') future._next_flag = True @@ -604,25 +757,35 @@ async def next(self, future: RuntimeFuture) -> list[tuple[int, Any]]: _worker = None -def start_worker(w_id: int | None, port: int, cpu: int | None = None) -> None: +def start_worker( + w_id: int | None, + port: int, + cpu: int | None = None, + logging_level: int = logging.WARNING, + num_blas_threads: int = 1, + log_client: bool = False, +) -> None: """Start this process's worker.""" if w_id is not None: # Ignore interrupt signals on workers, boss will handle it for us # If w_id is None, then we are being spawned separately. signal.signal(signal.SIGINT, signal.SIG_IGN) + # TODO: check what needs to be done on win - # Purge all standard python logging configurations - for _, logger in logging.Logger.manager.loggerDict.items(): - if isinstance(logger, logging.PlaceHolder): - continue - logger.handlers.clear() - logging.Logger.manager.loggerDict = {} + # Set number of BLAS threads + set_blas_thread_counts(num_blas_threads) + # Enforce no default logging + logging.lastResort = logging.NullHandler() + logging.getLogger().handlers.clear() + + # Pin worker to cpu if cpu is not None: if sys.platform == 'win32': raise RuntimeError('Cannot pin worker to cpu on windows.') os.sched_setaffinity(0, [cpu]) + # Connect to manager max_retries = 7 wait_time = .1 conn: Connection | None = None @@ -639,10 +802,29 @@ def start_worker(w_id: int | None, port: int, cpu: int | None = None) -> None: if conn is None: raise RuntimeError('Unable to establish connection with manager.') + # If id isn't provided, wait for assignment if w_id is None: msg, w_id = conn.recv() + assert isinstance(w_id, int) assert msg == RuntimeMessage.STARTED + # Set up runtime logging + if not log_client: + _runtime_logger = 
logging.getLogger('bqskit.runtime') + else: + _runtime_logger = logging.getLogger() + _runtime_logger.propagate = False + _runtime_logger.setLevel(logging_level) + _handler = logging.StreamHandler() + _handler.setLevel(0) + _fmt_header = '%(asctime)s.%(msecs)03d - %(levelname)-8s |' + _fmt_message = f' [wid={w_id}]: %(message)s' + _fmt = _fmt_header + _fmt_message + _formatter = logging.Formatter(_fmt, '%H:%M:%S') + _handler.setFormatter(_formatter) + _runtime_logger.addHandler(_handler) + + # Build and start worker global _worker _worker = Worker(w_id, conn) _worker._loop() @@ -690,6 +872,23 @@ def start_worker_rank() -> None: default=default_worker_port, help='The port the workers will try to connect to a manager on.', ) + parser.add_argument( + '-v', '--verbose', + action='count', + default=0, + help='Enable logging of increasing verbosity, either -v, -vv, or -vvv.', + ) + parser.add_argument( + '-l', '--log-client', + action='store_true', + help='Log messages from the client process.', + ) + parser.add_argument( + '-t', '--num_blas_threads', + type=int, + default=1, + help='The number of threads to use in BLAS libraries.', + ) args = parser.parse_args() if args.cpus is not None: @@ -709,10 +908,23 @@ def start_worker_rank() -> None: else: cpus = [None for _ in range(args.num_workers)] + logging_level = [30, 20, 10, 1][min(args.verbose, 3)] + + if args.log_client and logging_level > 10: + raise RuntimeError('Cannot log client messages without at least -vv.') + # Spawn worker process procs = [] for cpu in cpus: - procs.append(Process(target=start_worker, args=(None, args.port, cpu))) + pargs = ( + None, + args.port, + cpu, + logging_level, + args.num_blas_threads, + args.log_client, + ) + procs.append(Process(target=start_worker, args=pargs)) procs[-1].start() # Join them diff --git a/bqskit/utils/cachedclass.py b/bqskit/utils/cachedclass.py index 751a361b6..7edac47b9 100644 --- a/bqskit/utils/cachedclass.py +++ b/bqskit/utils/cachedclass.py @@ -63,7 +63,8 @@ def 
__new__(cls: type[T], *args: Any, **kwargs: Any) -> T: _instances = cls._instances # type: ignore if _instances.get(key, None) is None: - _logger.debug( + _logger.log( + 1, ( 'Creating cached instance for class: %s,' ' with args %s, and kwargs %s' diff --git a/bqskit/utils/math.py b/bqskit/utils/math.py index a95d76518..f2560044d 100644 --- a/bqskit/utils/math.py +++ b/bqskit/utils/math.py @@ -8,6 +8,7 @@ import scipy as sp from bqskit.qis.pauli import PauliMatrices +from bqskit.qis.pauliz import PauliZMatrices from bqskit.qis.unitary.unitary import RealVector @@ -19,7 +20,7 @@ def dexpmv( User must provide M and its derivative dM. If the argument dM is a vector of partials then dF will be the respective partial vector. - This is done using a Pade Approximat with scaling and squaring. + This is done using a Pade Approximation with scaling and squaring. Args: M (np.ndarray): Matrix to exponentiate. @@ -168,6 +169,36 @@ def pauli_expansion(H: npt.NDArray[np.complex128]) -> npt.NDArray[np.float64]: return np.array(X) +def pauliz_expansion(H: npt.NDArray[np.complex128]) -> npt.NDArray[np.float64]: + """ + Computes a Pauli Z expansion of the diagonal hermitian matrix H. + + Args: + H (np.ndarray): The (N, N) diagonal hermitian matrix to expand. + + Returns: + np.ndarray: The coefficients of a Pauli Z expansion for H, + i.e., x dot Sigma = H where Sigma contains Pauli Z matrices of + same size of H. + + Note: + This assumes the input is diagonal but of shape (N, N). No check is + done for hermicity. The output is undefined on non-hermitian inputs. + """ + diag_H = np.diag(np.diag(H)) + if not np.allclose(H, diag_H): + msg = 'H must be a diagonal matrix.' 
+ raise ValueError(msg) + # Change basis of H to Pauli Basis (solve for coefficients -> X) + n = int(np.log2(len(H))) + paulizs = PauliZMatrices(n) + flatten_paulizs = [np.diag(pauli) for pauli in paulizs] + flatten_H = np.diag(H) + A = np.stack(flatten_paulizs, axis=-1) + X = np.real(np.matmul(np.linalg.inv(A), flatten_H)) + return np.array(X) + + def compute_su_generators(n: int) -> npt.NDArray[np.complex128]: """ Computes the Lie algebra generators for SU(n). @@ -257,3 +288,27 @@ def canonical_unitary( correction_phase = 0 - std_phase std_correction = np.exp(1j * correction_phase) return std_correction * special_unitary + + +def diagonal_distance(unitary: npt.NDArray[np.complex128]) -> float: + """ + Compute how diagonal a unitary is. + + The diagonal distance measures how closely a unitary can be approx- + imately inverted by a diagonal unitary. A unitary is approximately + inverted when the Hilbert-Schmidt distance to the identity is less + than some threshold. + + A proof of correctness can be found in the appendix of: + https://arxiv.org/abs/2409.00433 + + Args: + unitary (np.ndarray): The unitary matrix to check. + + Returns: + float: The Hilbert-Schmidt distance to the nearest diagonal. 
+ """ + eps = unitary - np.diag(np.diag(unitary)) + eps2 = eps * eps.conj() + distance = abs(np.sqrt(eps2.sum(-1).max())) + return distance diff --git a/bqskit/utils/test/strategies.py b/bqskit/utils/test/strategies.py index de0e5e2dc..420ffc52a 100644 --- a/bqskit/utils/test/strategies.py +++ b/bqskit/utils/test/strategies.py @@ -40,6 +40,7 @@ from bqskit.ir.region import CircuitRegion from bqskit.qis.state.state import StateLike from bqskit.qis.state.state import StateVector +from bqskit.qis.unitary import RealVector from bqskit.qis.unitary import UnitaryMatrix from bqskit.qis.unitary.unitarymatrix import UnitaryLike from bqskit.utils.typing import is_integer @@ -258,6 +259,24 @@ def gates( return gate +@composite +def gates_and_params( + draw: Any, + radixes: Sequence[int] | int | None = None, + constant: bool | None = None, +) -> tuple[Gate, RealVector]: + """Hypothesis strategy for generating gates and parameters.""" + gate = draw(gates(radixes, constant)) + params = draw( + lists( + floats(allow_nan=False, allow_infinity=False, width=16), + min_size=gate.num_params, + max_size=gate.num_params, + ), + ) + return gate, params + + @composite def operations( draw: Any, diff --git a/docs/conf.py b/docs/conf.py index cfd238319..45036da82 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -38,6 +38,7 @@ 'myst_parser', 'jupyter_sphinx', 'nbsphinx', + 'sphinx_autodoc_typehints', ] # Add any paths that contain templates here, relative to this directory. @@ -100,7 +101,7 @@ 'pytket', 'cirq', 'qutip', - 'qiskit', + 'dill', ] nbsphinx_allow_errors = True nbsphinx_execute = 'never' diff --git a/docs/guides/customgate.md b/docs/guides/customgate.md new file mode 100644 index 000000000..eca66e8cd --- /dev/null +++ b/docs/guides/customgate.md @@ -0,0 +1,188 @@ +# Implement a Custom Gate + +BQSKit's claims great portability, and as such, most algorithms in BQSKit can +work natively with any gate set. 
We have included many commonly used gates +inside of the [`bqskit.ir.gates`](https://bqskit.readthedocs.io/en/latest/source/ir.html#module-bqskit.ir.gates) +subpackage, but you may want to experiment with your own gates. In this tutorial, +we will implement a custom gate in BQSKit. Since BQSKit's algorithms are built +on numerical instantiation, this process is usually as simple as defining a new +subclass with a unitary at a high-level. + +For example, let's look at the [`TGate`](https://bqskit.readthedocs.io/en/latest/source/autogen/bqskit.ir.gates.TGate.html#bqskit.ir.gates.TGate) definition in BQSKit: + +```python +... +class TGate(ConstantGate, QubitGate): + _num_qudits = 1 + _qasm_name = 't' + _utry = UnitaryMatrix( + [ + [1, 0], + [0, cmath.exp(1j * cmath.pi / 4)], + ], + ) +``` + +A gate is defined by subclassing [`Gate`](https://bqskit.readthedocs.io/en/latest/source/autogen/bqskit.ir.Gate.html#bqskit.ir.Gate), +however, there are some abstract subclasses that can be extended instead to simplify the process. For example, the [`TGate`](https://bqskit.readthedocs.io/en/latest/source/autogen/bqskit.ir.gates.TGate.html#bqskit.ir.gates.TGate) is a subclass of +[`ConstantGate`](https://bqskit.readthedocs.io/en/latest/source/autogen/bqskit.ir.gates.ConstantGate.html#bqskit.ir.gates.ConstantGate) and +[`QubitGate`](https://bqskit.readthedocs.io/en/latest/source/autogen/bqskit.ir.gates.QubitGate.html#bqskit.ir.gates.QubitGate). The [`ConstantGate`](https://bqskit.readthedocs.io/en/latest/source/autogen/bqskit.ir.gates.ConstantGate.html#bqskit.ir.gates.ConstantGate) +subclass is used for gates that have a fixed unitary matrix, and the [`QubitGate`](https://bqskit.readthedocs.io/en/latest/source/autogen/bqskit.ir.gates.QubitGate.html#bqskit.ir.gates.QubitGate) subclass is used for gates that act only on qubits -- rather than qudits. In the following sections, the process of defining a custom gate will be explained in more detail. 
+ +## Defining a Custom Gate + +To define a custom gate, you need to subclass [`Gate`](https://bqskit.readthedocs.io/en/latest/source/autogen/bqskit.ir.Gate.html#bqskit.ir.Gate), and +define all the required attributes. These attributes can be defined as instance variables, class variables, or through methods. The following +attributes are required: + +- [`_num_qudits`](https://bqskit.readthedocs.io/en/latest/source/autogen/bqskit.ir.Gate.num_qudits.html#bqskit.ir.Gate.num_qudits): The number of qudits the gate acts on. +- [`_num_params`](https://bqskit.readthedocs.io/en/latest/source/autogen/bqskit.ir.Gate.num_params.html#bqskit.ir.Gate.num_params): The number of parameters the gate takes. +- [`_radixes`](https://bqskit.readthedocs.io/en/latest/source/autogen/bqskit.ir.Gate.radixes.html#bqskit.ir.Gate.radixes): The radixes of the qudits this gate acts on. This is a tuple of integers, where each integer is the radix of the corresponding qudit. For example, `(2, 2)` would be a 2-qubit gate, `(3, 3)` would be a 2-qutrit gate, and `(2, 3, 3)` would be a gate that acts on a qubit and two qutrits. +- [`_name`](https://bqskit.readthedocs.io/en/latest/source/autogen/bqskit.ir.Gate.name.html#bqskit.ir.Gate.name): The name of the gate. This is used during print operations. +- [`_qasm_name`](https://bqskit.readthedocs.io/en/latest/source/autogen/bqskit.ir.Gate.qasm_name.html#bqskit.ir.Gate.qasm_name): The name of the gate in QASM. (Qubit only gates, should use lowercase, optional) + +Additionally, you will need to override the abstract method [`get_unitary`](https://bqskit.readthedocs.io/en/latest/source/autogen/bqskit.qis.Unitary.get_unitary.html#bqskit.qis.Unitary.get_unitary). This method maps the parameters of the gate to a unitary matrix. 
+ +Here is an example of a custom gate that acts on a single qubit: + +```python +import cmath +from bqskit.ir.gate import Gate +from bqskit.qis.unitary.unitarymatrix import UnitaryMatrix +from bqskit.qis.unitary.unitary import RealVector + +class MyGate(Gate): + _num_qudits = 1 + _num_params = 1 + _radixes = (2,) + _name = 'MyGate' + _qasm_name = 'mygate' + + def get_unitary(self, params: RealVector) -> UnitaryMatrix: + theta = params[0] + return UnitaryMatrix( + [ + [cmath.exp(1j * theta / 2), 0], + [0, cmath.exp(-1j * theta / 2)], + ], + ) +``` + +Note the `params` argument is a [`RealVector`](https://bqskit.readthedocs.io/en/latest/source/autogen/bqskit.qis.RealVector.html#bqskit.qis.RealVector) object, which is an alias for many types of float arrays. There is a helper method in the [`Gate`](https://bqskit.readthedocs.io/en/latest/source/autogen/bqskit.ir.Gate.html#bqskit.ir.Gate) class hierarchy called [`check_parameters`](https://bqskit.readthedocs.io/en/latest/source/autogen/bqskit.qis.Unitary.check_parameters.html#bqskit.qis.Unitary.check_parameters) that can be used to validate the parameters before using them. This will check for the correct types and lengths of the parameters: + +```python +... + def get_unitary(self, params: RealVector) -> UnitaryMatrix: + self.check_parameters(params) + ... + return UnitaryMatrix( + ... + ) +``` + +As mentioned previously, the required attributes can be defined as class variables, like in the above example, or as instance variables. 
The following example shows how to define a tensor product of an arbitrary number of `MyGate`s using instance variables: + +```python +import cmath +from bqskit.ir.gate import Gate +from bqskit.qis.unitary.unitarymatrix import UnitaryMatrix +from bqskit.qis.unitary.unitary import RealVector + +class MyGateTensor(Gate): + def __init__(self, num_qudits: int) -> None: + self._num_qudits = num_qudits + self._num_params = 1 + self._radixes = tuple([2] * num_qudits) + self._name = f'MyGateTensor{num_qudits}' + + def get_unitary(self, params: RealVector) -> UnitaryMatrix: + self.check_parameters(params) + theta = params[0] + base = UnitaryMatrix( + [ + [cmath.exp(1j * theta / 2), 0], + [0, cmath.exp(-1j * theta / 2)], + ], + ) + base.otimes(*[base] * (self._num_qudits - 1)) # base tensor product with itself + # Note: Since the unitary is diagonal, there are more efficient ways to + # compute the tensor product, but this is a simple example meant + # to demonstrate the concept. In general, you should always implement + # the most efficient method for your gate. +``` + +This style is helpful when the gate's attributes are dependent on the constructor arguments. + +The last way to define the attributes is through methods. The corresponding property names can be found on the [`Gate`](https://bqskit.readthedocs.io/en/latest/source/autogen/bqskit.ir.Gate.html#bqskit.ir.Gate) class. The following example computes the gate name of `MyGateTensor` through the `name` property: + +```python +... +class MyGateTensor(Gate): + ... # __init__ and get_unitary methods same as before without _name attribute + + @property + def name(self) -> str: + return f'MyGateTensor{self._num_qudits}' + +``` + +## Utilizing Helper Classes + +BQSKit provides some helper classes to simplify the process of defining gates. 
In the first example of this guide, we used the [`ConstantGate`](https://bqskit.readthedocs.io/en/latest/source/autogen/bqskit.ir.gates.ConstantGate.html#bqskit.ir.gates.ConstantGate) and [`QubitGate`](https://bqskit.readthedocs.io/en/latest/source/autogen/bqskit.ir.gates.QubitGate.html#bqskit.ir.gates.QubitGate) helper classes. To use these helper subclasses, we will subclass them instead of [`Gate`](https://bqskit.readthedocs.io/en/latest/source/autogen/bqskit.ir.Gate.html#bqskit.ir.Gate). The following are the available helper classes: + +- [`ConstantGate`](https://bqskit.readthedocs.io/en/latest/source/autogen/bqskit.ir.gates.ConstantGate.html#bqskit.ir.gates.ConstantGate): A gate that has a fixed unitary matrix with no parameters. This will automatically set `_num_params` to 0, and swap the `get_unitary` method for a `_utry` attribute. Additionally, these gates have the trivial differentiable implementations provided. +- [`QubitGate`](https://bqskit.readthedocs.io/en/latest/source/autogen/bqskit.ir.gates.QubitGate.html#bqskit.ir.gates.QubitGate): A gate that acts only on qubits. This defines `_radixes` to be all `2`s. +- [`QutritGate`](https://bqskit.readthedocs.io/en/latest/source/autogen/bqskit.ir.gates.QutritGate.html#bqskit.ir.gates.QutritGate): A gate that acts only on qutrits. This defines `_radixes` to be all `3`s. +- [`QuditGate`](https://bqskit.readthedocs.io/en/latest/source/autogen/bqskit.ir.gates.QuditGate.html#bqskit.ir.gates.QuditGate): A gate that acts on qudits of the same radix. This swaps the `_radixes` requirement for a required `_radix` attribute. This is useful for gates that act on qudits of the same radix, but not necessarily only qubits or qutrits. +- [`ComposedGate`](https://bqskit.readthedocs.io/en/latest/source/autogen/bqskit.ir.gates.ComposedGate.html#bqskit.ir.gates.ComposedGate): A gate that is composed of other gates. This provides methods to dynamically determine if the gate is differentiable or optimizable via other means. 
+ +## Differentiable Gates + +If you are implementing a parameterized gate, you may want to make it differentiable. By making a gate differentiable, you allow it to be used by BQSKit's instantiation engine. In turn, this allows synthesis and other algorithms to work more easily with these gates. To do this, you will need to additionally subclass [`DifferentiableUnitary`](https://bqskit.readthedocs.io/en/latest/source/autogen/bqskit.qis.DifferentiableUnitary.html) and implement the [`get_grad`](https://bqskit.readthedocs.io/en/latest/source/autogen/bqskit.qis.DifferentiableUnitary.get_grad.html#bqskit.qis.DifferentiableUnitary.get_grad) method. `ConstantGate`s are trivially differentiable, as they have no parameters. + +Most of the time, the [`get_unitary_and_grad`](https://bqskit.readthedocs.io/en/latest/source/autogen/bqskit.qis.DifferentiableUnitary.get_unitary_and_grad.html#bqskit.qis.DifferentiableUnitary.get_unitary_and_grad) method is called by other parts of BQSKit, since both the unitary and gradient are typically needed at the same time. For most gates, computing them at the same time can allow for greater efficiency, since the unitary and gradient can share some computations. 
+ +Let's make `MyGate` differentiable: + +```python +import cmath +from bqskit.ir.gate import Gate +from bqskit.qis.unitary.unitarymatrix import UnitaryMatrix +from bqskit.qis.unitary.unitary import RealVector +from bqskit.qis.unitary.differentiableunitary import DifferentiableUnitary + +class MyGate(Gate, DifferentiableUnitary): + _num_qudits = 1 + _num_params = 1 + _radixes = (2,) + _name = 'MyGate' + _qasm_name = 'mygate' + + def get_unitary(self, params: RealVector) -> UnitaryMatrix: + self.check_parameters(params) + theta = params[0] + return UnitaryMatrix( + [ + [cmath.exp(1j * theta / 2), 0], + [0, cmath.exp(-1j * theta / 2)], + ], + ) + + def get_grad(self, params: RealVector) -> npt.NDArray[np.complex128]: + self.check_parameters(params) + theta = params[0] + return np.array( + [ + [ + [1j / 2 * cmath.exp(1j * theta / 2), 0], + [0, -1j / 2 * cmath.exp(-1j * theta / 2)], + ], + ], + ) +``` + +The `get_grad` method should return a 3D array, where the first index is the parameter index. `get_grad(params)[i]` should return the gradient of the unitary with respect to the `i`-th parameter. The gradient should be a matrix of the same shape as the unitary matrix, where each element is the derivative of the unitary matrix element with respect to the parameter. + +## Working with QASM + +If you want to use your gate in QASM, you will need to define the `_qasm_name` attribute. This is the name of the gate in QASM. However, some gates require special qasm definitions to be included at the top of a qasm file. This can be achieved by defining the [`get_qasm_gate_def`](https://bqskit.readthedocs.io/en/latest/source/autogen/bqskit.ir.Gate.get_qasm_gate_def.html#bqskit.ir.Gate.get_qasm_gate_def) method. This method returns a string, which will be included as-is at the top of every qasm file that uses the gate. 
diff --git a/docs/guides/distributing.md b/docs/guides/distributing.md index 38966bd15..c79c34e72 100644 --- a/docs/guides/distributing.md +++ b/docs/guides/distributing.md @@ -1,4 +1,4 @@ -# Distributing BQSKit Across a Cluster +# Distribute BQSKit Across a Cluster This guide describes how to launch a BQSKit Runtime Server in detached mode on one or more computers, connect to it, and perform compilations on the server. diff --git a/docs/index.rst b/docs/index.rst index e3de9ed83..1b00e56b1 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -25,9 +25,10 @@ our `tutorial series. `_ :caption: Guides :maxdepth: 1 + guides/customgate.md + guides/custompass.md guides/distributing.md guides/usegpus.md - guides/custompass.md .. toctree:: :caption: API Reference diff --git a/docs/requirements.txt b/docs/requirements.txt index 45bb175e7..a88d4eaa7 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -2,6 +2,7 @@ Sphinx>=4.5.0 sphinx-autodoc-typehints>=1.12.0 sphinx-rtd-theme>=1.0.0 sphinx-togglebutton>=0.2.3 +sphinx-autodoc-typehints>=2.3.0 sphinxcontrib-applehelp>=1.0.2 sphinxcontrib-devhelp>=1.0.2 sphinxcontrib-htmlhelp>=2.0.0 diff --git a/setup.py b/setup.py index 005681ffd..8edcd43de 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ root_dir_path = os.path.abspath(os.path.dirname(__file__)) pkg_dir_path = os.path.join(root_dir_path, 'bqskit') readme_path = os.path.join(root_dir_path, 'README.md') -version_path = os.path.join(pkg_dir_path, 'version.py') +version_path = os.path.join(pkg_dir_path, '_version.py') # Load Version Number with open(version_path) as version_file: @@ -71,6 +71,7 @@ 'numpy>=1.22.0', 'scipy>=1.8.0', 'typing-extensions>=4.0.0', + 'dill>=0.3.8', ], python_requires='>=3.8, <4', entry_points={ diff --git a/tests/compiler/test_data.py b/tests/compiler/test_data.py index 075a95c82..934215db6 100644 --- a/tests/compiler/test_data.py +++ b/tests/compiler/test_data.py @@ -26,3 +26,9 @@ def test_update_error_mul() -> None: assert data.error 
== 0.75 data.update_error_mul(0.5) assert data.error == 0.875 + + +def test_target_doesnt_get_expanded_on_update() -> None: + data = PassData(Circuit(64)) + data2 = PassData(Circuit(64)) + data.update(data2) # Should not crash diff --git a/tests/compiler/test_gateset.py b/tests/compiler/test_gateset.py index 009003a95..4f89b8fca 100644 --- a/tests/compiler/test_gateset.py +++ b/tests/compiler/test_gateset.py @@ -522,3 +522,16 @@ def test_gate_set_repr() -> None: repr(gate_set) == 'GateSet({CNOTGate, U3Gate})' or repr(gate_set) == 'GateSet({U3Gate, CNOTGate})' ) + + +def test_gate_set_hash() -> None: + gate_set_1 = GateSet({CNOTGate(), U3Gate()}) + gate_set_2 = GateSet({U3Gate(), CNOTGate()}) + gate_set_3 = GateSet({U3Gate(), CNOTGate(), RZGate()}) + + h1 = hash(gate_set_1) + h2 = hash(gate_set_2) + h3 = hash(gate_set_3) + + assert h1 == h2 + assert h1 != h3 diff --git a/tests/compiler/test_registry.py b/tests/compiler/test_registry.py new file mode 100644 index 000000000..4cc6065bc --- /dev/null +++ b/tests/compiler/test_registry.py @@ -0,0 +1,152 @@ +"""This file tests the register_workflow function.""" +from __future__ import annotations + +from itertools import combinations +from random import choice + +import pytest +from numpy import allclose + +from bqskit.compiler.compile import compile +from bqskit.compiler.machine import MachineModel +from bqskit.compiler.registry import _compile_circuit_registry +from bqskit.compiler.registry import _compile_statemap_registry +from bqskit.compiler.registry import _compile_stateprep_registry +from bqskit.compiler.registry import _compile_unitary_registry +from bqskit.compiler.registry import register_workflow +from bqskit.compiler.workflow import Workflow +from bqskit.compiler.workflow import WorkflowLike +from bqskit.ir import Circuit +from bqskit.ir import Gate +from bqskit.ir.gates import CZGate +from bqskit.ir.gates import HGate +from bqskit.ir.gates import RZGate +from bqskit.ir.gates import U3Gate +from bqskit.passes 
import QSearchSynthesisPass +from bqskit.passes import QuickPartitioner +from bqskit.passes import ScanningGateRemovalPass + + +def machine_match(mach_a: MachineModel, mach_b: MachineModel) -> bool: + if mach_a.num_qudits != mach_b.num_qudits: + return False + if mach_a.radixes != mach_b.radixes: + return False + if mach_a.coupling_graph != mach_b.coupling_graph: + return False + if mach_a.gate_set != mach_b.gate_set: + return False + return True + + +def unitary_match(unit_a: Circuit, unit_b: Circuit) -> bool: + return allclose(unit_a.get_unitary(), unit_b.get_unitary(), atol=1e-5) + + +def workflow_match( + workflow_a: WorkflowLike, + workflow_b: WorkflowLike, +) -> bool: + if not isinstance(workflow_a, Workflow): + workflow_a = Workflow(workflow_a) + if not isinstance(workflow_b, Workflow): + workflow_b = Workflow(workflow_b) + if len(workflow_a) != len(workflow_b): + return False + for a, b in zip(workflow_a, workflow_b): + if a.name != b.name: + return False + return True + + +def simple_circuit(num_qudits: int, gate_set: list[Gate]) -> Circuit: + circ = Circuit(num_qudits) + gate = choice(gate_set) + if gate.num_qudits == 1: + loc = choice(range(num_qudits)) + else: + loc = choice(list(combinations(range(num_qudits), 2))) # type: ignore + gate_inv = gate.get_inverse() + circ.append_gate(gate, loc) + circ.append_gate(gate_inv, loc) + return circ + + +class TestRegisterWorkflow: + + @pytest.fixture(autouse=True) + def setup(self) -> None: + # global _compile_registry + _compile_circuit_registry.clear() + _compile_unitary_registry.clear() + _compile_statemap_registry.clear() + _compile_stateprep_registry.clear() + + def test_register_workflow(self) -> None: + global _compile_circuit_registry + global _compile_unitary_registry + global _compile_statemap_registry + global _compile_stateprep_registry + assert _compile_circuit_registry == {} + assert _compile_unitary_registry == {} + assert _compile_statemap_registry == {} + assert _compile_stateprep_registry == {} 
+ gateset = [CZGate(), HGate(), RZGate()] + num_qudits = 3 + machine = MachineModel(num_qudits, gate_set=gateset) + circuit_workflow = [QuickPartitioner(), ScanningGateRemovalPass()] + other_workflow = [QuickPartitioner(), QSearchSynthesisPass()] + register_workflow(machine, circuit_workflow, 1, 'circuit') + register_workflow(machine, other_workflow, 1, 'unitary') + register_workflow(machine, other_workflow, 1, 'statemap') + register_workflow(machine, other_workflow, 1, 'stateprep') + assert machine in _compile_circuit_registry + assert 1 in _compile_circuit_registry[machine] + assert workflow_match( + _compile_circuit_registry[machine][1], circuit_workflow, + ) + assert machine in _compile_unitary_registry + assert 1 in _compile_unitary_registry[machine] + assert workflow_match( + _compile_unitary_registry[machine][1], other_workflow, + ) + assert machine in _compile_statemap_registry + assert 1 in _compile_statemap_registry[machine] + assert workflow_match( + _compile_statemap_registry[machine][1], other_workflow, + ) + assert machine in _compile_stateprep_registry + assert 1 in _compile_stateprep_registry[machine] + assert workflow_match( + _compile_stateprep_registry[machine][1], other_workflow, + ) + + def test_custom_compile_machine(self) -> None: + global _compile_circuit_registry + assert _compile_circuit_registry == {} + gateset = [CZGate(), HGate(), RZGate()] + num_qudits = 3 + machine = MachineModel(num_qudits, gate_set=gateset) + workflow = [QuickPartitioner(2)] + register_workflow(machine, workflow, 1, 'circuit') + circuit = simple_circuit(num_qudits, gateset) + result = compile(circuit, machine) + assert unitary_match(result, circuit) + assert result.num_operations > 0 + assert result.gate_counts != circuit.gate_counts + result.unfold_all() + assert result.gate_counts == circuit.gate_counts + + def test_custom_opt_level(self) -> None: + global _compile_circuit_registry + assert _compile_circuit_registry == {} + gateset = [CZGate(), HGate(), RZGate()] 
+ num_qudits = 3 + machine = MachineModel(num_qudits, gate_set=gateset) + workflow = [QSearchSynthesisPass()] + register_workflow(machine, workflow, 2, 'circuit') + circuit = simple_circuit(num_qudits, gateset) + result = compile(circuit, machine, optimization_level=2) + assert unitary_match(result, circuit) + assert result.gate_counts != circuit.gate_counts + assert U3Gate() in result.gate_set diff --git a/tests/ir/circuit/test_region_methods.py b/tests/ir/circuit/test_region_methods.py index 18c34c86c..f6629fd49 100644 --- a/tests/ir/circuit/test_region_methods.py +++ b/tests/ir/circuit/test_region_methods.py @@ -13,6 +13,7 @@ from bqskit.ir.gates.constant.h import HGate from bqskit.ir.gates.constant.x import XGate from bqskit.ir.gates.parameterized.u3 import U3Gate +from bqskit.ir.location import CircuitLocation from bqskit.ir.point import CircuitPoint from bqskit.ir.point import CircuitPointLike from bqskit.ir.region import CircuitRegion @@ -296,7 +297,7 @@ def test_small_circuit_2(self) -> None: circuit.append_gate(HGate(), 1) circuit.append_gate(HGate(), 2) region = circuit.surround((0, 1), 2) - assert region == CircuitRegion({0: (0, 1), 1: (0, 2)}) + assert region == CircuitRegion({0: (0, 1), 1: (0, 3)}) def test_small_circuit_3(self) -> None: circuit = Circuit(3) @@ -324,7 +325,7 @@ def test_through_middle_of_outside(self) -> None: circuit.append_gate(CNOTGate(), (0, 2)) circuit.append_gate(CNOTGate(), (0, 1)) region = circuit.surround((1, 0), 2) - assert region == CircuitRegion({0: (0, 1), 1: (0, 1)}) + assert region == CircuitRegion({0: (0, 1), 1: (0, 2)}) def test_with_fold(self, r6_qudit_circuit: Circuit) -> None: cycle = 0 @@ -338,3 +339,88 @@ def test_with_fold(self, r6_qudit_circuit: Circuit) -> None: region = r6_qudit_circuit.surround((cycle, qudit), 4) r6_qudit_circuit.fold(region) assert r6_qudit_circuit.get_unitary() == utry + + def test_surround_symmetric(self) -> None: + circuit = Circuit(6) + # whole wall of even + 
circuit.append_gate(CNOTGate(), [0, 1]) + circuit.append_gate(CNOTGate(), [2, 3]) + circuit.append_gate(CNOTGate(), [4, 5]) + + # one odd gate; problematic point in test + circuit.append_gate(CNOTGate(), [3, 4]) + + # whole wall of even + circuit.append_gate(CNOTGate(), [0, 1]) + circuit.append_gate(CNOTGate(), [2, 3]) + circuit.append_gate(CNOTGate(), [4, 5]) + + region = circuit.surround((1, 3), 4) + assert region.location == CircuitLocation([2, 3, 4, 5]) + + def test_surround_filter_hard(self) -> None: + circuit = Circuit(7) + # whole wall of even + circuit.append_gate(CNOTGate(), [0, 1]) + circuit.append_gate(CNOTGate(), [2, 3]) + circuit.append_gate(CNOTGate(), [4, 5]) + + # one odd gate; problematic point in test + circuit.append_gate(CNOTGate(), [3, 4]) + + # whole wall of even + circuit.append_gate(CNOTGate(), [0, 1]) + circuit.append_gate(CNOTGate(), [2, 3]) + circuit.append_gate(CNOTGate(), [4, 5]) + + # more odd gates to really test filter + circuit.append_gate(CNOTGate(), [5, 6]) + circuit.append_gate(CNOTGate(), [5, 6]) + circuit.append_gate(CNOTGate(), [5, 6]) + circuit.append_gate(CNOTGate(), [5, 6]) + circuit.append_gate(CNOTGate(), [5, 6]) + + region = circuit.surround( + (1, 3), 4, None, None, lambda region: ( + region.min_qudit > 1 and region.max_qudit < 6 + ), + ) + assert region.location == CircuitLocation([2, 3, 4, 5]) + + def test_surround_filter_topology(self) -> None: + circuit = Circuit(5) + circuit.append_gate(CNOTGate(), [0, 1]) + circuit.append_gate(CNOTGate(), [0, 2]) + circuit.append_gate(CNOTGate(), [0, 1]) + circuit.append_gate(CNOTGate(), [0, 2]) + circuit.append_gate(CNOTGate(), [1, 2]) + circuit.append_gate(CNOTGate(), [2, 3]) + circuit.append_gate(CNOTGate(), [3, 4]) + + def region_filter(region: CircuitRegion) -> bool: + return circuit.get_slice(region.points).coupling_graph.is_linear() + + region = circuit.surround( + (4, 1), 4, None, None, lambda region: ( + region_filter(region) + ), + ) + assert 
circuit.is_valid_region(region) + assert region.location == CircuitLocation([1, 2, 3, 4]) + + +def test_check_region_1() -> None: + c = Circuit(4) + c.append_gate(CNOTGate(), [1, 2]) + c.append_gate(CNOTGate(), [0, 1]) + c.append_gate(CNOTGate(), [2, 3]) + c.append_gate(CNOTGate(), [1, 2]) + assert not c.is_valid_region({1: (0, 2), 2: (0, 2), 3: (0, 2)}) + + +def test_check_region_2() -> None: + c = Circuit(3) + c.append_gate(CNOTGate(), [0, 1]) + c.append_gate(CNOTGate(), [0, 2]) + c.append_gate(CNOTGate(), [1, 2]) + assert not c.is_valid_region({0: (0, 0), 1: (0, 2), 2: (2, 2)}) diff --git a/tests/ir/gates/composed/test_power.py b/tests/ir/gates/composed/test_power.py new file mode 100644 index 000000000..82f5c3634 --- /dev/null +++ b/tests/ir/gates/composed/test_power.py @@ -0,0 +1,58 @@ +# type: ignore +"""This module tests the PowerGate class.""" +from __future__ import annotations + +import numpy as np +import numpy.typing as npt +from hypothesis import given +from hypothesis.strategies import integers + +from bqskit.ir.gate import Gate +from bqskit.ir.gates import PowerGate +from bqskit.qis.unitary.differentiable import DifferentiableUnitary +from bqskit.qis.unitary.unitary import RealVector +from bqskit.qis.unitary.unitarymatrix import UnitaryMatrix +from bqskit.utils.test.strategies import gates_and_params + + +def _recursively_calc_power_grad( + g: UnitaryMatrix, + dg: npt.NDArray[np.complex128], + power: int, +) -> npt.NDArray[np.complex128]: + """D(g^n+1) = d(g@g^n) = g @ d(g^n) + dg @ g^n.""" + if len(dg) == 0 or power == 0: + return np.zeros_like(dg) + if power < 0: + return _recursively_calc_power_grad( + g.dagger, + dg.conj().transpose([0, 2, 1]), + -power, + ) + if power == 1: + return dg + dgn = _recursively_calc_power_grad(g, dg, power - 1) + return g @ dgn + dg @ g.ipower(power - 1) + + +@given(gates_and_params(), integers(min_value=-10, max_value=10)) +def test_power_gate(g_and_p: tuple[Gate, RealVector], power: int) -> None: + gate, params = 
g_and_p + pgate = PowerGate(gate, power) + actual_unitary = pgate.get_unitary(params) + expected_unitary = gate.get_unitary(params).ipower(power) + assert actual_unitary.isclose(expected_unitary) + + if not isinstance(gate, DifferentiableUnitary): + return + + if gate.num_params == 0: + return + + actual_grad = pgate.get_grad(params) + expected_grad = _recursively_calc_power_grad( + gate.get_unitary(params), + gate.get_grad(params), + power, + ) + assert np.allclose(actual_grad, expected_grad) diff --git a/tests/ir/gates/parameterized/test_pauliz.py b/tests/ir/gates/parameterized/test_pauliz.py new file mode 100644 index 000000000..61ee78811 --- /dev/null +++ b/tests/ir/gates/parameterized/test_pauliz.py @@ -0,0 +1,61 @@ +"""This module tests the PauliZGate class.""" +from __future__ import annotations + +import numpy as np +import pytest +from hypothesis import given +from hypothesis.strategies import floats +from hypothesis.strategies import integers + +from bqskit.ir.gates import IdentityGate +from bqskit.ir.gates import PauliZGate +from bqskit.ir.gates import RZGate +from bqskit.ir.gates import RZZGate +from bqskit.utils.test.strategies import num_qudits + + +class TestInit: + @given(num_qudits(4)) + def test_valid(self, num_qudits: int) -> None: + g = PauliZGate(num_qudits) + assert g.num_qudits == num_qudits + assert g.num_params == 2 ** num_qudits + identity = np.identity(2 ** num_qudits) + assert g.get_unitary([0] * 2 ** num_qudits) == identity + + @given(integers(max_value=0)) + def test_invalid(self, num_qudits: int) -> None: + with pytest.raises(ValueError): + PauliZGate(num_qudits) + + +class TestGetUnitary: + @given(floats(allow_nan=False, allow_infinity=False, width=16)) + def test_i(self, angle: float) -> None: + g = PauliZGate(1) + i = IdentityGate(1).get_unitary() + dist = g.get_unitary([angle, 0]).get_distance_from(i) + assert dist < 1e-7 + + @given(floats(allow_nan=False, allow_infinity=False, width=16)) + def test_z(self, angle: float) -> None: 
+ g = PauliZGate(1) + z = RZGate() + assert g.get_unitary([0, angle]) == z.get_unitary([angle]) + + @given(floats(allow_nan=False, allow_infinity=False, width=16)) + def test_zz(self, angle: float) -> None: + g = PauliZGate(2) + zz = RZZGate() + params = [0.0] * 4 + params[3] = angle + assert g.get_unitary(params) == zz.get_unitary([angle]) + + +@given(floats(allow_nan=False, allow_infinity=False, width=16)) +def test_optimize(angle: float) -> None: + g = PauliZGate(1) + z = RZGate() + utry = z.get_unitary([angle]) + params = g.optimize(np.array(utry)) + assert g.get_unitary(params).get_distance_from(utry.conj().T) < 1e-7 diff --git a/tests/passes/control/predicates/test_diagonal.py b/tests/passes/control/predicates/test_diagonal.py new file mode 100644 index 000000000..f20f32009 --- /dev/null +++ b/tests/passes/control/predicates/test_diagonal.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from itertools import combinations +from random import choices + +import numpy as np +from hypothesis import given +from hypothesis.strategies import integers + +from bqskit.compiler.passdata import PassData +from bqskit.ir.circuit import Circuit +from bqskit.ir.gates import CNOTGate +from bqskit.ir.gates import HGate +from bqskit.ir.gates import RZGate +from bqskit.ir.gates import SXGate +from bqskit.passes.control.predicates.diagonal import DiagonalPredicate + + +def phase_gadget() -> Circuit: + gadget = Circuit(2) + gadget.append_gate(CNOTGate(), (0, 1)) + gadget.append_gate(RZGate(), (1), [np.random.normal()]) + gadget.append_gate(CNOTGate(), (0, 1)) + return gadget + + +@given(integers(2, 6), integers(0, 10)) +def test_diagonal_predicate(num_qudits: int, num_gadgets: int) -> None: + circuit = Circuit(num_qudits) + all_locations = list(combinations(range(num_qudits), r=2)) + locations = choices(all_locations, k=num_gadgets) + for location in locations: + circuit.append_circuit(phase_gadget(), location) + data = PassData(circuit) + pred = DiagonalPredicate(1e-5) + 
+ is_diagonal = True + assert pred.get_truth_value(circuit, data) == is_diagonal + + circuit.append_gate(HGate(), (0)) + assert not pred.get_truth_value(circuit, data) == is_diagonal + + +@given(integers(1, 10)) +def test_single_qubit_diagonal_predicate(exponent: int) -> None: + angle = 10 ** - exponent + circuit = Circuit(1) + circuit.append_gate(RZGate(), (0), [angle]) + circuit.append_gate(SXGate(), (0)) + circuit.append_gate(RZGate(), (0), [np.random.normal()]) + circuit.append_gate(SXGate(), (0)) + circuit.append_gate(RZGate(), (0), [angle]) + + pred = DiagonalPredicate(1e-5) + data = PassData(circuit) + # This is true by the small angle approximation + assert pred.get_truth_value(circuit, data) == (angle < 1e-5) diff --git a/tests/passes/control/test_paralleldo.py b/tests/passes/control/test_paralleldo.py index 2d69b5e70..91a5c4f72 100644 --- a/tests/passes/control/test_paralleldo.py +++ b/tests/passes/control/test_paralleldo.py @@ -38,11 +38,11 @@ async def run(self, circuit: Circuit, data: PassData) -> None: data['key'] = '1' -class Sleep3Pass(BasePass): +class Sleep9Pass(BasePass): async def run(self, circuit: Circuit, data: PassData) -> None: circuit.append_gate(ZGate(), 0) - time.sleep(0.3) - data['key'] = '3' + time.sleep(0.9) + data['key'] = '9' def pick_z(c1: Circuit, c2: Circuit) -> bool: @@ -66,7 +66,7 @@ def test_parallel_do_no_passes() -> None: def test_parallel_do_pick_first(compiler: Compiler) -> None: - passes: list[list[BasePass]] = [[Sleep3Pass()], [Sleep1Pass()]] + passes: list[list[BasePass]] = [[Sleep9Pass()], [Sleep1Pass()]] pd_pass = ParallelDo(passes, pick_z, True) _, data = compiler.compile(Circuit(1), pd_pass, True) assert data['key'] == '1' diff --git a/tests/passes/search/generators/test_discrete.py b/tests/passes/search/generators/test_discrete.py new file mode 100644 index 000000000..dedaaa18a --- /dev/null +++ b/tests/passes/search/generators/test_discrete.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from random import 
randint + +from bqskit.compiler.passdata import PassData +from bqskit.ir.circuit import Circuit +from bqskit.ir.gates import CNOTGate +from bqskit.ir.gates import HGate +from bqskit.ir.gates import TGate +from bqskit.passes.search.generators.discrete import DiscreteLayerGenerator + + +class TestDiscreteLayerGenerator: + + def test_gate_set(self) -> None: + gates = [HGate(), CNOTGate(), TGate()] + generator = DiscreteLayerGenerator() + assert all(g in generator.gateset for g in gates) + + def test_double_headed(self) -> None: + single_gen = DiscreteLayerGenerator(double_headed=False) + double_gen = DiscreteLayerGenerator(double_headed=True) + base = Circuit(4) + single_sucs = single_gen.gen_successors(base, PassData(base)) + double_sucs = double_gen.gen_successors(base, PassData(base)) + assert len(single_sucs) == len(double_sucs) + + base = Circuit(2) + base.append_gate(CNOTGate(), (0, 1)) + single_sucs = single_gen.gen_successors(base, PassData(base)) + double_sucs = double_gen.gen_successors(base, PassData(base)) + assert len(single_sucs) < len(double_sucs) + assert all(c in double_sucs for c in single_sucs) + + def test_cancels_something(self) -> None: + gen = DiscreteLayerGenerator() + base = Circuit(2) + base.append_gate(HGate(), (0,)) + base.append_gate(TGate(), (0,)) + base.append_gate(HGate(), (0,)) + assert gen.cancels_something(base, HGate(), (0,)) + assert not gen.cancels_something(base, HGate(), (1,)) + assert not gen.cancels_something(base, TGate(), (0,)) + + def test_count_repeats(self) -> None: + num_repeats = randint(1, 50) + c = Circuit(1) + for _ in range(num_repeats): + c.append_gate(HGate(), (0,)) + gen = DiscreteLayerGenerator() + assert gen.count_repeats(c, HGate(), 0) == num_repeats diff --git a/tests/passes/synthesis/test_diagonal.py b/tests/passes/synthesis/test_diagonal.py new file mode 100644 index 000000000..e72b6f500 --- /dev/null +++ b/tests/passes/synthesis/test_diagonal.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from 
numpy.random import normal +from scipy.linalg import expm + +from bqskit.compiler import Compiler +from bqskit.ir.circuit import Circuit +from bqskit.passes.synthesis.diagonal import WalshDiagonalSynthesisPass +from bqskit.qis import UnitaryMatrix +from bqskit.qis.pauliz import PauliZMatrices + + +class TestWalshDiagonalSynthesis: + + def test_1_qubit(self, compiler: Compiler) -> None: + num_qubits = 1 + pauliz = PauliZMatrices(num_qubits) + vector = [normal() for _ in range(len(pauliz))] + H_matrix = pauliz.dot_product(vector) + utry = UnitaryMatrix(expm(1j * H_matrix)) + + circuit = Circuit.from_unitary(utry) + synthesis = WalshDiagonalSynthesisPass() + circuit = compiler.compile(circuit, [synthesis]) + dist = circuit.get_unitary().get_distance_from(utry) + + assert dist <= 1e-5 + + def test_2_qubit(self, compiler: Compiler) -> None: + num_qubits = 2 + pauliz = PauliZMatrices(num_qubits) + vector = [normal() for _ in range(len(pauliz))] + H_matrix = pauliz.dot_product(vector) + utry = UnitaryMatrix(expm(1j * H_matrix)) + + circuit = Circuit.from_unitary(utry) + synthesis = WalshDiagonalSynthesisPass() + circuit = compiler.compile(circuit, [synthesis]) + dist = circuit.get_unitary().get_distance_from(utry) + + assert dist <= 1e-5 + + def test_3_qubit(self, compiler: Compiler) -> None: + num_qubits = 3 + pauliz = PauliZMatrices(num_qubits) + vector = [normal() for _ in range(len(pauliz))] + H_matrix = pauliz.dot_product(vector) + utry = UnitaryMatrix(expm(1j * H_matrix)) + + circuit = Circuit.from_unitary(utry) + synthesis = WalshDiagonalSynthesisPass() + circuit = compiler.compile(circuit, [synthesis]) + dist = circuit.get_unitary().get_distance_from(utry) + + assert dist <= 1e-5 + + def test_4_qubit(self, compiler: Compiler) -> None: + num_qubits = 4 + pauliz = PauliZMatrices(num_qubits) + vector = [normal() for _ in range(len(pauliz))] + H_matrix = pauliz.dot_product(vector) + utry = UnitaryMatrix(expm(1j * H_matrix)) + + circuit = Circuit.from_unitary(utry) + 
synthesis = WalshDiagonalSynthesisPass() + circuit = compiler.compile(circuit, [synthesis]) + dist = circuit.get_unitary().get_distance_from(utry) + + assert dist <= 1e-5 + + def test_5_qubit(self, compiler: Compiler) -> None: + num_qubits = 5 + pauliz = PauliZMatrices(num_qubits) + vector = [normal() for _ in range(len(pauliz))] + H_matrix = pauliz.dot_product(vector) + utry = UnitaryMatrix(expm(1j * H_matrix)) + + circuit = Circuit.from_unitary(utry) + synthesis = WalshDiagonalSynthesisPass() + circuit = compiler.compile(circuit, [synthesis]) + dist = circuit.get_unitary().get_distance_from(utry) + + assert dist <= 1e-5 diff --git a/tests/qis/test_graph.py b/tests/qis/test_graph.py index d4ea681df..158238959 100644 --- a/tests/qis/test_graph.py +++ b/tests/qis/test_graph.py @@ -1,9 +1,217 @@ """This module tests the CouplingGraph class.""" from __future__ import annotations +from typing import Any + import pytest from bqskit.qis.graph import CouplingGraph +from bqskit.qis.graph import CouplingGraphLike + + +def test_coupling_graph_init_valid() -> None: + # Test with valid inputs + graph = {(0, 1), (1, 2), (2, 3)} + num_qudits = 4 + remote_edges = [(1, 2)] + default_weight = 1.0 + default_remote_weight = 10.0 + edge_weights_overrides = {(1, 2): 0.5} + + coupling_graph = CouplingGraph( + graph, + num_qudits, + remote_edges, + default_weight, + default_remote_weight, + edge_weights_overrides, + ) + + assert coupling_graph.num_qudits == num_qudits + assert coupling_graph._edges == graph + assert coupling_graph._remote_edges == set(remote_edges) + assert coupling_graph.default_weight == default_weight + assert coupling_graph.default_remote_weight == default_remote_weight + assert all( + coupling_graph._mat[q1][q2] == weight + for (q1, q2), weight in edge_weights_overrides.items() + ) + + +@pytest.mark.parametrize( + 'graph, num_qudits, remote_edges, default_weight, default_remote_weight,' + ' edge_weights_overrides, expected_exception', + [ + # Invalid graph + (None, 
4, [], 1.0, 100.0, {}, TypeError), + # num_qudits is not an integer + ({(0, 1)}, '4', [], 1.0, 100.0, {}, TypeError), + # num_qudits is negative + ({(0, 1)}, -1, [], 1.0, 100.0, {}, ValueError), + # Invalid remote_edges + ({(0, 1)}, 4, None, 1.0, 100.0, {}, TypeError), + # Remote edge not in graph + ({(0, 1)}, 4, [(1, 2)], 1.0, 100.0, {}, ValueError), + # Invalid default_weight + ({(0, 1)}, 4, [], '1.0', 100.0, {}, TypeError), + # Invalid default_remote_weight + ({(0, 1)}, 4, [], 1.0, '100.0', {}, TypeError), + # Invalid edge_weights_overrides + ({(0, 1)}, 4, [], 1.0, 100.0, None, TypeError), + # Non-integer value in edge_weights_overrides + ({(0, 1)}, 4, [], 1.0, 100.0, {(0, 1): '0.5'}, TypeError), + # Edge in edge_weights_overrides not in graph + ({(0, 1)}, 4, [], 1.0, 100.0, {(1, 2): 0.5}, ValueError), + ], +) +def test_coupling_graph_init_invalid( + graph: CouplingGraphLike, + num_qudits: Any, + remote_edges: Any, + default_weight: Any, + default_remote_weight: Any, + edge_weights_overrides: Any, + expected_exception: Exception, +) -> None: + with pytest.raises(expected_exception): + CouplingGraph( + graph, + num_qudits, + remote_edges, + default_weight, + default_remote_weight, + edge_weights_overrides, + ) + + +def test_get_qpu_to_qudit_map_single_qpu() -> None: + graph = CouplingGraph([(0, 1), (1, 2), (2, 3)]) + expected_map = [[0, 1, 2, 3]] + assert graph.get_qpu_to_qudit_map() == expected_map + + +def test_get_qpu_to_qudit_map_multiple_qpus() -> None: + graph = CouplingGraph([(0, 1), (1, 2), (2, 3)], remote_edges=[(1, 2)]) + expected_map = [[0, 1], [2, 3]] + assert graph.get_qpu_to_qudit_map() == expected_map + + +def test_get_qpu_to_qudit_map_disconnected() -> None: + graph = CouplingGraph([(0, 1), (1, 2), (3, 4)], remote_edges=[(1, 2)]) + expected_map = [[0, 1], [2], [3, 4]] + assert graph.get_qpu_to_qudit_map() == expected_map + + +def test_get_qpu_to_qudit_map_empty_graph() -> None: + graph = CouplingGraph([]) + expected_map = [[0]] + assert 
graph.get_qpu_to_qudit_map() == expected_map + + +def test_get_qpu_to_qudit_map_complex_topology() -> None: + graph = CouplingGraph( + [(0, 1), (1, 2), (0, 2), (2, 5), (3, 4), (4, 5), (3, 5)], + remote_edges=[(2, 5)], + ) + expected_map = [[0, 1, 2], [3, 4, 5]] + assert graph.get_qpu_to_qudit_map() == expected_map + + +def test_get_qudit_to_qpu_map_three_qpu() -> None: + graph = CouplingGraph( + [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)], + remote_edges=[(2, 3), (5, 6)], + ) + expected_map = [[0, 1, 2], [3, 4, 5], [6, 7]] + assert graph.get_qpu_to_qudit_map() == expected_map + + +def test_is_distributed() -> None: + graph = CouplingGraph([(0, 1), (1, 2), (2, 3)]) + assert not graph.is_distributed() + + graph = CouplingGraph([(0, 1), (1, 2), (2, 3)], remote_edges=[(1, 2)]) + assert graph.is_distributed() + + graph = CouplingGraph([(0, 1), (1, 2), (2, 3)], remote_edges=[(1, 2)]) + assert graph.is_distributed() + + graph = CouplingGraph( + [(0, 1), (1, 2), (2, 3)], + remote_edges=[(1, 2), (2, 3)], + ) + assert graph.is_distributed() + + +def test_qpu_count() -> None: + graph = CouplingGraph([(0, 1), (1, 2), (2, 3)]) + assert graph.qpu_count() == 1 + + graph = CouplingGraph([(0, 1), (1, 2), (2, 3)], remote_edges=[(1, 2)]) + assert graph.qpu_count() == 2 + + graph = CouplingGraph( + [(0, 1), (1, 2), (2, 3)], + remote_edges=[(1, 2), (2, 3)], + ) + assert graph.qpu_count() == 3 + + graph = CouplingGraph([]) + assert graph.qpu_count() == 1 + + +def test_get_individual_qpu_graphs() -> None: + graph = CouplingGraph([(0, 1), (1, 2), (2, 3)]) + qpus = graph.get_individual_qpu_graphs() + assert len(qpus) == 1 + assert qpus[0] == graph + + graph = CouplingGraph([(0, 1), (1, 2), (2, 3)], remote_edges=[(1, 2)]) + qpus = graph.get_individual_qpu_graphs() + assert len(qpus) == 2 + assert qpus[0] == CouplingGraph([(0, 1)]) + assert qpus[1] == CouplingGraph([(0, 1)]) + + graph = CouplingGraph( + [(0, 1), (1, 2), (2, 3)], + remote_edges=[(1, 2), (2, 3)], + ) + qpus = 
graph.get_individual_qpu_graphs() + assert len(qpus) == 3 + assert qpus[0] == CouplingGraph([(0, 1)]) + assert qpus[1] == CouplingGraph([]) + assert qpus[2] == CouplingGraph([]) + + +def test_get_qudit_to_qpu_map() -> None: + graph = CouplingGraph([(0, 1), (1, 2), (2, 3)]) + assert graph.get_qudit_to_qpu_map() == [0, 0, 0, 0] + + graph = CouplingGraph([(0, 1), (1, 2), (2, 3)], remote_edges=[(1, 2)]) + assert graph.get_qudit_to_qpu_map() == [0, 0, 1, 1] + + graph = CouplingGraph( + [(0, 1), (1, 2), (2, 3)], + remote_edges=[(1, 2), (2, 3)], + ) + assert graph.get_qudit_to_qpu_map() == [0, 0, 1, 2] + + graph = CouplingGraph([]) + assert graph.get_qudit_to_qpu_map() == [0] + + +def test_get_qpu_connectivity() -> None: + graph = CouplingGraph([(0, 1), (1, 2), (2, 3)]) + assert graph.get_qpu_connectivity() == [set()] + + graph = CouplingGraph([(0, 1), (1, 2), (2, 3)], remote_edges=[(1, 2)]) + assert graph.get_qpu_connectivity() == [{1}, {0}] + + graph = CouplingGraph( + [(0, 1), (1, 2), (2, 3)], + remote_edges=[(1, 2), (2, 3)], + ) + assert graph.get_qpu_connectivity() == [{1}, {0, 2}, {1}] class TestGraphGetSubgraphsOfSize: @@ -63,3 +271,14 @@ def test_invalid(self) -> None: with pytest.raises(TypeError): coupling_graph.get_subgraph('a') # type: ignore + + +def test_is_linear() -> None: + coupling_graph = CouplingGraph({(0, 1), (1, 2), (2, 3)}) + assert coupling_graph.is_linear() + + coupling_graph = CouplingGraph({(0, 1), (1, 2), (0, 3), (2, 3)}) + assert not coupling_graph.is_linear() + + coupling_graph = CouplingGraph.all_to_all(4) + assert not coupling_graph.is_linear() diff --git a/tests/qis/test_pauli.py b/tests/qis/test_pauli.py index 8e8c3ef4f..261566821 100644 --- a/tests/qis/test_pauli.py +++ b/tests/qis/test_pauli.py @@ -10,6 +10,7 @@ from hypothesis.strategies import integers from bqskit.qis.pauli import PauliMatrices +from bqskit.qis.pauliz import PauliZMatrices from bqskit.qis.unitary.unitary import RealVector from bqskit.utils.test.types import 
invalid_type_test from bqskit.utils.test.types import valid_type_test @@ -800,3 +801,472 @@ def test_multi( assert all(isinstance(pauli, np.ndarray) for pauli in paulis) assert len(paulis) == len(pauli_mats) assert all(self.in_array(pauli, pauli_mats) for pauli in paulis) + + +class TestPauliZMatricesConstructor: + def in_array(self, needle: Any, haystack: Any) -> bool: + for elem in haystack: + if np.allclose(elem, needle): + return True + + return False + + @invalid_type_test(PauliZMatrices) + def test_invalid_type(self) -> None: + pass + + @given(integers(max_value=-1)) + def test_invalid_value(self, size: int) -> None: + with pytest.raises(ValueError): + PauliZMatrices(size) + + def test_size_1(self) -> None: + num_qubits = 1 + paulis = PauliZMatrices(num_qubits) + assert len(paulis) == 2 ** num_qubits + + I = np.array([[1, 0], [0, 1]], dtype=np.complex128) + Z = np.array([[1, 0], [0, -1]], dtype=np.complex128) + + assert self.in_array(I, paulis) + assert self.in_array(Z, paulis) + + def test_size_2(self) -> None: + num_qubits = 2 + paulis = PauliZMatrices(num_qubits) + assert len(paulis) == 2 ** num_qubits + + I = np.array([[1, 0], [0, 1]], dtype=np.complex128) + Z = np.array([[1, 0], [0, -1]], dtype=np.complex128) + + assert self.in_array(np.kron(Z, Z), paulis) + assert self.in_array(np.kron(Z, I), paulis) + assert self.in_array(np.kron(I, Z), paulis) + assert self.in_array(np.kron(I, I), paulis) + + def test_size_3(self) -> None: + num_qubits = 3 + paulis = PauliZMatrices(num_qubits) + assert len(paulis) == 2 ** num_qubits + + I = np.array([[1, 0], [0, 1]], dtype=np.complex128) + Z = np.array([[1, 0], [0, -1]], dtype=np.complex128) + + assert self.in_array(np.kron(Z, np.kron(Z, Z)), paulis) + assert self.in_array(np.kron(Z, np.kron(Z, I)), paulis) + assert self.in_array(np.kron(Z, np.kron(I, Z)), paulis) + assert self.in_array(np.kron(Z, np.kron(I, I)), paulis) + assert self.in_array(np.kron(I, np.kron(Z, Z)), paulis) + assert self.in_array(np.kron(I, 
np.kron(Z, I)), paulis) + assert self.in_array(np.kron(I, np.kron(I, Z)), paulis) + assert self.in_array(np.kron(I, np.kron(I, I)), paulis) + + +class TestPauliZMatricesGetProjectionMatrices: + def in_array(self, needle: Any, haystack: Any) -> bool: + for elem in haystack: + if np.allclose(elem, needle): + return True + + return False + + @valid_type_test(PauliZMatrices(1).get_projection_matrices) + def test_valid_type(self) -> None: + pass + + @invalid_type_test(PauliZMatrices(1).get_projection_matrices) + def test_invalid_type(self) -> None: + pass + + @pytest.mark.parametrize('invalid_qubit', [-5, -2, 4, 10]) + def test_invalid_value_1(self, invalid_qubit: int) -> None: + paulis = PauliZMatrices(4) + with pytest.raises(ValueError): + paulis.get_projection_matrices([invalid_qubit]) + + @pytest.mark.parametrize('invalid_q_set', [[0, 0], [0, 1, 2, 4]]) + def test_invalid_value_2(self, invalid_q_set: list[int]) -> None: + paulis = PauliZMatrices(4) + with pytest.raises(ValueError): + paulis.get_projection_matrices(invalid_q_set) + + def test_proj_3_0(self) -> None: + num_qubits = 3 + qubit_proj = 0 + paulis = PauliZMatrices(num_qubits) + projs = paulis.get_projection_matrices([qubit_proj]) + assert len(projs) == 2 + + I = np.array([[1, 0], [0, 1]], dtype=np.complex128) + Z = np.array([[1, 0], [0, -1]], dtype=np.complex128) + + assert self.in_array(np.kron(np.kron(Z, I), I), projs) + assert self.in_array(np.kron(np.kron(I, I), I), projs) + + def test_proj_3_1(self) -> None: + num_qubits = 3 + qubit_proj = 1 + paulis = PauliZMatrices(num_qubits) + projs = paulis.get_projection_matrices([qubit_proj]) + assert len(projs) == 2 + + I = np.array([[1, 0], [0, 1]], dtype=np.complex128) + Z = np.array([[1, 0], [0, -1]], dtype=np.complex128) + + assert self.in_array(np.kron(np.kron(I, Z), I), projs) + assert self.in_array(np.kron(np.kron(I, I), I), projs) + + def test_proj_3_2(self) -> None: + num_qubits = 3 + qubit_proj = 2 + paulis = PauliZMatrices(num_qubits) + projs = 
paulis.get_projection_matrices([qubit_proj]) + assert len(projs) == 2 + + I = np.array([[1, 0], [0, 1]], dtype=np.complex128) + Z = np.array([[1, 0], [0, -1]], dtype=np.complex128) + + assert self.in_array(np.kron(np.kron(I, I), Z), projs) + assert self.in_array(np.kron(np.kron(I, I), I), projs) + + def test_proj_4_0(self) -> None: + num_qubits = 4 + qubit_proj = 0 + paulis = PauliZMatrices(num_qubits) + projs = paulis.get_projection_matrices([qubit_proj]) + assert len(projs) == 2 + + I = np.array([[1, 0], [0, 1]], dtype=np.complex128) + Z = np.array([[1, 0], [0, -1]], dtype=np.complex128) + + assert self.in_array(np.kron(np.kron(np.kron(Z, I), I), I), projs) + assert self.in_array(np.kron(np.kron(np.kron(I, I), I), I), projs) + + def test_proj_4_1(self) -> None: + num_qubits = 4 + qubit_proj = 1 + paulis = PauliZMatrices(num_qubits) + projs = paulis.get_projection_matrices([qubit_proj]) + assert len(projs) == 2 + + I = np.array([[1, 0], [0, 1]], dtype=np.complex128) + Z = np.array([[1, 0], [0, -1]], dtype=np.complex128) + + assert self.in_array(np.kron(np.kron(np.kron(I, Z), I), I), projs) + assert self.in_array(np.kron(np.kron(np.kron(I, I), I), I), projs) + + def test_proj_4_2(self) -> None: + num_qubits = 4 + qubit_proj = 2 + paulis = PauliZMatrices(num_qubits) + projs = paulis.get_projection_matrices([qubit_proj]) + assert len(projs) == 2 + + I = np.array([[1, 0], [0, 1]], dtype=np.complex128) + Z = np.array([[1, 0], [0, -1]], dtype=np.complex128) + + assert self.in_array(np.kron(np.kron(np.kron(I, I), Z), I), projs) + assert self.in_array(np.kron(np.kron(np.kron(I, I), I), I), projs) + + def test_proj_4_3(self) -> None: + num_qubits = 4 + qubit_proj = 3 + paulis = PauliZMatrices(num_qubits) + projs = paulis.get_projection_matrices([qubit_proj]) + assert len(projs) == 2 + + I = np.array([[1, 0], [0, 1]], dtype=np.complex128) + Z = np.array([[1, 0], [0, -1]], dtype=np.complex128) + + assert self.in_array(np.kron(np.kron(np.kron(I, I), I), Z), projs) + assert 
self.in_array(np.kron(np.kron(np.kron(I, I), I), I), projs) + + def test_proj_3_01(self) -> None: + num_qubits = 3 + qubit_pro1 = 0 + qubit_pro2 = 1 + paulis = PauliZMatrices(num_qubits) + projs = paulis.get_projection_matrices([qubit_pro1, qubit_pro2]) + assert len(projs) == 4 + + I = np.array([[1, 0], [0, 1]], dtype=np.complex128) + Z = np.array([[1, 0], [0, -1]], dtype=np.complex128) + + assert self.in_array(np.kron(np.kron(Z, I), I), projs) + assert self.in_array(np.kron(np.kron(I, I), I), projs) + assert self.in_array(np.kron(np.kron(Z, Z), I), projs) + assert self.in_array(np.kron(np.kron(I, Z), I), projs) + + def test_proj_3_02(self) -> None: + num_qubits = 3 + qubit_pro1 = 0 + qubit_pro2 = 2 + paulis = PauliZMatrices(num_qubits) + projs = paulis.get_projection_matrices([qubit_pro1, qubit_pro2]) + assert len(projs) == 4 + + I = np.array([[1, 0], [0, 1]], dtype=np.complex128) + Z = np.array([[1, 0], [0, -1]], dtype=np.complex128) + + assert self.in_array(np.kron(np.kron(Z, I), I), projs) + assert self.in_array(np.kron(np.kron(I, I), I), projs) + assert self.in_array(np.kron(np.kron(Z, I), Z), projs) + assert self.in_array(np.kron(np.kron(I, I), Z), projs) + + def test_proj_3_12(self) -> None: + num_qubits = 3 + qubit_pro1 = 1 + qubit_pro2 = 2 + paulis = PauliZMatrices(num_qubits) + projs = paulis.get_projection_matrices([qubit_pro1, qubit_pro2]) + assert len(projs) == 4 + + I = np.array([[1, 0], [0, 1]], dtype=np.complex128) + Z = np.array([[1, 0], [0, -1]], dtype=np.complex128) + + assert self.in_array(np.kron(np.kron(I, Z), I), projs) + assert self.in_array(np.kron(np.kron(I, I), I), projs) + assert self.in_array(np.kron(np.kron(I, Z), Z), projs) + assert self.in_array(np.kron(np.kron(I, I), Z), projs) + + def test_proj_3_012(self) -> None: + num_qubits = 3 + paulis = PauliZMatrices(num_qubits) + projs = paulis.get_projection_matrices([0, 1, 2]) + assert len(projs) == 8 + + I = np.array([[1, 0], [0, 1]], dtype=np.complex128) + Z = np.array([[1, 0], [0, 
-1]], dtype=np.complex128) + + assert self.in_array(np.kron(np.kron(I, Z), I), projs) + assert self.in_array(np.kron(np.kron(I, I), I), projs) + assert self.in_array(np.kron(np.kron(I, Z), Z), projs) + assert self.in_array(np.kron(np.kron(I, I), Z), projs) + + assert self.in_array(np.kron(np.kron(Z, Z), I), projs) + assert self.in_array(np.kron(np.kron(Z, I), I), projs) + assert self.in_array(np.kron(np.kron(Z, Z), Z), projs) + assert self.in_array(np.kron(np.kron(Z, I), Z), projs) + + def test_proj_4_02(self) -> None: + num_qubits = 4 + qubit_pro1 = 0 + qubit_pro2 = 2 + paulis = PauliZMatrices(num_qubits) + projs = paulis.get_projection_matrices([qubit_pro1, qubit_pro2]) + assert len(projs) == 4 + + I = np.array([[1, 0], [0, 1]], dtype=np.complex128) + Z = np.array([[1, 0], [0, -1]], dtype=np.complex128) + + assert self.in_array(np.kron(np.kron(np.kron(Z, I), I), I), projs) + assert self.in_array(np.kron(np.kron(np.kron(I, I), I), I), projs) + assert self.in_array(np.kron(np.kron(np.kron(Z, I), Z), I), projs) + assert self.in_array(np.kron(np.kron(np.kron(I, I), Z), I), projs) + + +class TestPauliZMatricesDotProduct: + @pytest.mark.parametrize('invalid_alpha', [[1.1] * i for i in range(2)]) + def test_invalid_value(self, invalid_alpha: RealVector) -> None: + with pytest.raises(ValueError): + PauliZMatrices(1).dot_product(invalid_alpha) + + @pytest.mark.parametrize( + 'alpha, prod', [ + ([1, 0], PauliZMatrices.I), + ([0, 1], PauliZMatrices.Z), + ([1, 1], PauliZMatrices.I + PauliZMatrices.Z), + ], + ) + def test_size_1(self, alpha: RealVector, prod: npt.NDArray[Any]) -> None: + assert np.allclose(PauliZMatrices(1).dot_product(alpha), prod) + + @pytest.mark.parametrize( + 'alpha, prod', [ + ( + [1, 0, 0, 0], + np.kron(PauliZMatrices.I, PauliZMatrices.I), + ), + ( + [0, 1, 0, 0], + np.kron(PauliZMatrices.I, PauliZMatrices.Z), + ), + ( + [0, 0, 1, 0], + np.kron(PauliZMatrices.Z, PauliZMatrices.I), + ), + ( + [0, 0, 0, 1], + np.kron(PauliZMatrices.Z, PauliZMatrices.Z), 
+ ), + ( + [1, 0, 0, 1], + np.kron(PauliZMatrices.I, PauliZMatrices.I) + + np.kron(PauliZMatrices.Z, PauliZMatrices.Z), + ), + ( + [1.8, 0, 0, 91.7], + 1.8 * np.kron(PauliZMatrices.I, PauliZMatrices.I) + + 91.7 * np.kron(PauliZMatrices.Z, PauliZMatrices.Z), + ), + ], + ) + def test_size_2( + self, alpha: RealVector, + prod: npt.NDArray[np.complex128], + ) -> None: + assert np.allclose(PauliZMatrices(2).dot_product(alpha), prod) + + +class TestPauliZMatricesFromString: + def in_array(self, needle: Any, haystack: Any) -> bool: + for elem in haystack: + if not needle.shape == elem.shape: + continue + if np.allclose(elem, needle): + return True + + return False + + @valid_type_test(PauliZMatrices.from_string) + def test_valid_type(self) -> None: + pass + + @invalid_type_test(PauliZMatrices.from_string) + def test_invalid_type(self) -> None: + pass + + @pytest.mark.parametrize( + 'invalid_str', [ + 'ABC', + 'IXYZA', + '\t AIXYZ ,, \n\r\tabc\t', + 'IXYZ+', + 'IXYZ, IXA', + 'WXYZ, XYZ', + ], + ) + def test_invalid_value(self, invalid_str: str) -> None: + with pytest.raises(ValueError): + PauliZMatrices.from_string(invalid_str) + + @pytest.mark.parametrize( + 'pauli_str, pauli_mat', [ + ( + 'IZZ', + np.kron( + np.kron( + PauliZMatrices.I, + PauliZMatrices.Z, + ), + PauliZMatrices.Z, + ), + ), + ( + 'ZIZ', + np.kron( + np.kron( + PauliZMatrices.Z, + PauliZMatrices.I, + ), + PauliZMatrices.Z, + ), + ), + ( + 'ZZI', + np.kron( + np.kron( + PauliZMatrices.Z, + PauliZMatrices.Z, + ), + PauliZMatrices.I, + ), + ), + ('\t ZZ ,,\n\r\t\t', np.kron(PauliZMatrices.Z, PauliZMatrices.Z)), + ], + ) + def test_single( + self, + pauli_str: str, + pauli_mat: npt.NDArray[np.complex128], + ) -> None: + assert isinstance(PauliZMatrices.from_string(pauli_str), np.ndarray) + assert np.allclose( + np.array(PauliZMatrices.from_string(pauli_str)), + pauli_mat, + ) + + @pytest.mark.parametrize( + 'pauli_str, pauli_mats', [ + ( + 'IIZ, IIZ', [ + np.kron( + np.kron( + PauliZMatrices.I, + 
PauliZMatrices.I, + ), + PauliZMatrices.Z, + ), + np.kron( + np.kron( + PauliZMatrices.I, + PauliZMatrices.I, + ), + PauliZMatrices.Z, + ), + ], + ), + ( + 'ZIZ, ZZI', [ + np.kron( + np.kron( + PauliZMatrices.Z, + PauliZMatrices.I, + ), + PauliZMatrices.Z, + ), + np.kron( + np.kron( + PauliZMatrices.Z, + PauliZMatrices.Z, + ), + PauliZMatrices.I, + ), + ], + ), + ( + 'IIZ, IZI, ZZZ', [ + np.kron( + np.kron( + PauliZMatrices.I, + PauliZMatrices.I, + ), + PauliZMatrices.Z, + ), + np.kron( + np.kron( + PauliZMatrices.I, + PauliZMatrices.Z, + ), + PauliZMatrices.I, + ), + np.kron( + np.kron( + PauliZMatrices.Z, + PauliZMatrices.Z, + ), + PauliZMatrices.Z, + ), + ], + ), + ], + ) + def test_multi( + self, pauli_str: str, + pauli_mats: list[npt.NDArray[np.complex128]], + ) -> None: + paulis = PauliZMatrices.from_string(pauli_str) + assert isinstance(paulis, list) + assert all(isinstance(pauli, np.ndarray) for pauli in paulis) + assert len(paulis) == len(pauli_mats) + assert all(self.in_array(pauli, pauli_mats) for pauli in paulis) diff --git a/tests/qis/unitary/test_props.py b/tests/qis/unitary/test_props.py new file mode 100644 index 000000000..2df433a69 --- /dev/null +++ b/tests/qis/unitary/test_props.py @@ -0,0 +1,8 @@ +from __future__ import annotations + +from bqskit.ir.circuit import Circuit + + +def test_circuit_dim_overflow() -> None: + c = Circuit(1024) + assert c.dim != 0 diff --git a/tests/qis/unitary/test_unitarymatrix.py b/tests/qis/unitary/test_unitarymatrix.py index 046057c45..045ddc433 100644 --- a/tests/qis/unitary/test_unitarymatrix.py +++ b/tests/qis/unitary/test_unitarymatrix.py @@ -216,3 +216,24 @@ def test_scalar_multiplication(self, u: UnitaryMatrix, a: float) -> None: out2 = a * u assert out2 is not u assert not isinstance(out2, UnitaryMatrix) + + +@given(unitaries(), integers(min_value=-10, max_value=10)) +def test_ipower(u: UnitaryMatrix, n: int) -> None: + out = u.ipower(n) + if n == 0: + assert out == UnitaryMatrix.identity(u.dim, u.radixes) + 
elif n == 1: + assert out == u + elif n == -1: + assert out == u.dagger + elif n < 0: + acm = u.dagger + for _ in range(-n - 1): + acm = acm @ u.dagger + assert out == acm + else: + acm = u + for _ in range(n - 1): + acm = acm @ u + assert out == acm diff --git a/tests/runtime/test_attached.py b/tests/runtime/test_attached.py index 0c2ecb67e..aa155dc3c 100644 --- a/tests/runtime/test_attached.py +++ b/tests/runtime/test_attached.py @@ -17,16 +17,16 @@ from bqskit.runtime import get_runtime -@pytest.mark.parametrize('num_workers', [1, -1]) -def test_startup_shutdown_transparently(num_workers: int) -> None: - in_num_childs = len(psutil.Process(os.getpid()).children(recursive=True)) - compiler = Compiler(num_workers=num_workers) - assert compiler.p is not None - compiler.__del__() - if sys.platform == 'win32': - time.sleep(1) - out_num_childs = len(psutil.Process(os.getpid()).children(recursive=True)) - assert in_num_childs == out_num_childs +# @pytest.mark.parametrize('num_workers', [1, -1]) +# def test_startup_shutdown_transparently(num_workers: int) -> None: +# in_num_childs = len(psutil.Process(os.getpid()).children(recursive=True)) +# compiler = Compiler(num_workers=num_workers) +# assert compiler.p is not None +# compiler.__del__() +# if sys.platform == 'win32': +# time.sleep(1) +# out_num_childs = len(psutil.Process(os.getpid()).children(recursive=True)) +# assert in_num_childs == out_num_childs @pytest.mark.parametrize('num_workers', [1, -1]) @@ -60,15 +60,17 @@ def test_create_workers(num_workers: int) -> None: compiler.close() -def test_one_thread_per_worker() -> None: - # On windows we aren't sure how the threads are handeled +def test_two_thread_per_worker() -> None: if sys.platform == 'win32': - return + pytest.skip('Not sure how to count threads on Windows.') + + if sys.platform == 'darwin': + pytest.skip('MacOS requires permissions to count threads.') compiler = Compiler(num_workers=1) assert compiler.p is not None assert 
len(psutil.Process(compiler.p.pid).children()) in [1, 2] - assert psutil.Process(compiler.p.pid).children()[0].num_threads() == 1 + assert psutil.Process(compiler.p.pid).children()[0].num_threads() == 2 compiler.close() diff --git a/tests/runtime/test_logging.py b/tests/runtime/test_logging.py index bee4bfb0a..4c8694439 100644 --- a/tests/runtime/test_logging.py +++ b/tests/runtime/test_logging.py @@ -2,7 +2,9 @@ from __future__ import annotations import logging +import pickle from io import StringIO +from typing import Any import pytest @@ -143,6 +145,55 @@ def test_using_external_logging(server_compiler: Compiler) -> None: logger.setLevel(logging.WARNING) +class ExternalWithArgsPass(BasePass): + async def run(self, circuit: Circuit, data: PassData) -> None: + logging.getLogger('dummy2').debug('int %d', 1) + + +def test_external_logging_with_args(server_compiler: Compiler) -> None: + logger = logging.getLogger('dummy2') + logger.setLevel(logging.DEBUG) + handler = logging.StreamHandler(StringIO()) + handler.setLevel(logging.DEBUG) + logger.addHandler(handler) + server_compiler.compile(Circuit(1), [ExternalWithArgsPass()]) + log = handler.stream.getvalue() + assert 'int 1' in log + logger.removeHandler(handler) + logger.setLevel(logging.WARNING) + + +class NonSerializable: + def __reduce__(self) -> str | tuple[Any, ...]: + raise pickle.PicklingError('This class is not serializable') + + def __str__(self) -> str: + return 'NonSerializable' + + +class ExternalWithNonSerializableArgsPass(BasePass): + async def run(self, circuit: Circuit, data: PassData) -> None: + logging.getLogger('dummy2').debug( + 'NonSerializable %s', + NonSerializable(), + ) + + +def test_external_logging_with_nonserializable_args( + server_compiler: Compiler, +) -> None: + logger = logging.getLogger('dummy2') + logger.setLevel(logging.DEBUG) + handler = logging.StreamHandler(StringIO()) + handler.setLevel(logging.DEBUG) + logger.addHandler(handler) + server_compiler.compile(Circuit(1), 
[ExternalWithNonSerializableArgsPass()]) + log = handler.stream.getvalue() + assert 'NonSerializable NonSerializable' in log + logger.removeHandler(handler) + logger.setLevel(logging.WARNING) + + @pytest.mark.parametrize('level', [-1, 0, 1, 2, 3, 4]) def test_limiting_nested_calls_enable_logging( server_compiler: Compiler, diff --git a/tests/runtime/test_next.py b/tests/runtime/test_next.py index 30642a7d7..81c45da5a 100644 --- a/tests/runtime/test_next.py +++ b/tests/runtime/test_next.py @@ -29,7 +29,7 @@ async def run(self, circuit: Circuit, data: PassData) -> None: class TestNoDuplicateResultsInTwoNexts(BasePass): async def run(self, circuit: Circuit, data: PassData) -> None: - future = get_runtime().map(sleepi, [0.3, 0.4, 0.1, 0.2]) + future = get_runtime().map(sleepi, [0.3, 0.4, 0.1, 0.2, 5]) seen = [0] int_ids = await get_runtime().next(future) diff --git a/tests/utils/test_math.py b/tests/utils/test_math.py index 0f21167dd..2e9036b51 100644 --- a/tests/utils/test_math.py +++ b/tests/utils/test_math.py @@ -10,10 +10,13 @@ from scipy.stats import unitary_group from bqskit.qis.pauli import PauliMatrices +from bqskit.qis.pauliz import PauliZMatrices from bqskit.utils.math import canonical_unitary from bqskit.utils.math import dexpmv +from bqskit.utils.math import diagonal_distance from bqskit.utils.math import dot_product from bqskit.utils.math import pauli_expansion +from bqskit.utils.math import pauliz_expansion from bqskit.utils.math import softmax from bqskit.utils.math import unitary_log_no_i @@ -188,6 +191,21 @@ def test_valid(self, reH: npt.NDArray[np.complex128]) -> None: assert np.linalg.norm(H - reH) < 1e-16 +class TestPauliZExpansion: + @pytest.mark.parametrize( + 'reH', + PauliZMatrices(1).paulizs + + PauliZMatrices(2).paulizs + + PauliZMatrices(3).paulizs + + PauliZMatrices(4).paulizs, + ) + def test_valid(self, reH: npt.NDArray[np.complex128]) -> None: + alpha = pauliz_expansion(reH) + print(alpha) + H = 
PauliZMatrices(int(np.log2(reH.shape[0]))).dot_product(alpha) + assert np.linalg.norm(H - reH) < 1e-16 + + class TestCanonicalUnitary: @pytest.mark.parametrize( 'phase, num_qudits', @@ -206,3 +224,32 @@ def test_canonical_unitary( phased_unitary = phase * base_unitary recanon_unitary = canonical_unitary(phased_unitary) assert np.allclose(canon_unitary, recanon_unitary, atol=1e-5) + + +class TestDiagonalDistance: + @pytest.mark.parametrize( + 'num_qudits, epsilon, threshold_list', + [ + (n, 10 ** -e, [10 ** -t for t in range(1, 10)]) + for n in range(1, 4) + for e in range(1, 10) + ], + ) + def test_diagonal_distance( + self, + num_qudits: int, + epsilon: float, + threshold_list: list[float], + ) -> None: + N = 2 ** num_qudits + off_diag = epsilon / (N - 1) + on_diag = 1 - epsilon + matrix = -off_diag * np.ones((N, N), dtype=np.complex128) + np.fill_diagonal(matrix, on_diag) + + for threshold in threshold_list: + distance = diagonal_distance(matrix) + if epsilon <= threshold: + assert distance <= threshold + else: + assert distance > threshold