diff --git a/qiskit/primitives/backend_estimator.py b/qiskit/primitives/backend_estimator.py index 35d6cd2eaced..329805bc319e 100644 --- a/qiskit/primitives/backend_estimator.py +++ b/qiskit/primitives/backend_estimator.py @@ -80,7 +80,7 @@ def _prepare_counts(results: list[Result]): return counts -class BackendEstimator(BaseEstimator): +class BackendEstimator(BaseEstimator[PrimitiveJob[EstimatorResult]]): """Evaluates expectation value using Pauli rotation gates. The :class:`~.BackendEstimator` class is a generic implementation of the @@ -250,7 +250,7 @@ def _run( observables: tuple[BaseOperator | PauliSumOp, ...], parameter_values: tuple[tuple[float, ...], ...], **run_options, - ) -> PrimitiveJob: + ): circuit_indices = [] for circuit in circuits: index = self._circuit_ids.get(_circuit_key(circuit)) diff --git a/qiskit/primitives/backend_sampler.py b/qiskit/primitives/backend_sampler.py index c5d66ebd11cd..6d7f16173980 100644 --- a/qiskit/primitives/backend_sampler.py +++ b/qiskit/primitives/backend_sampler.py @@ -30,7 +30,7 @@ from .utils import _circuit_key -class BackendSampler(BaseSampler): +class BackendSampler(BaseSampler[PrimitiveJob[SamplerResult]]): """A :class:`~.BaseSampler` implementation that provides an interface for leveraging the sampler interface from any backend. @@ -192,7 +192,7 @@ def _run( circuits: tuple[QuantumCircuit, ...], parameter_values: tuple[tuple[float, ...], ...], **run_options, - ) -> PrimitiveJob: + ): circuit_indices = [] for circuit in circuits: index = self._circuit_ids.get(_circuit_key(circuit)) diff --git a/qiskit/primitives/base/base_estimator.py b/qiskit/primitives/base/base_estimator.py index 1cac61331b86..f091bc0d335c 100644 --- a/qiskit/primitives/base/base_estimator.py +++ b/qiskit/primitives/base/base_estimator.py @@ -83,6 +83,7 @@ from abc import abstractmethod from collections.abc import Sequence from copy import copy +from typing import Generic, TypeVar from qiskit.circuit import QuantumCircuit from qiskit.circuit.parametertable import ParameterView @@ -94,8 +95,10 @@ from ..utils import init_observable from .base_primitive import BasePrimitive +T = TypeVar("T", bound=Job) -class BaseEstimator(BasePrimitive): + +class BaseEstimator(BasePrimitive, Generic[T]): """Estimator base class. Base class for Estimator that estimates expectation values of quantum circuits and observables. @@ -126,7 +129,7 @@ def run( observables: Sequence[BaseOperator | PauliSumOp | str] | BaseOperator | PauliSumOp | str, parameter_values: Sequence[Sequence[float]] | Sequence[float] | float | None = None, **run_options, - ) -> Job: + ) -> T: """Run the job of the estimation of expectation value(s). ``circuits``, ``observables``, and ``parameter_values`` should have the same @@ -193,7 +196,7 @@ def _run( observables: tuple[SparsePauliOp, ...], parameter_values: tuple[tuple[float, ...], ...], **run_options, - ) -> Job: + ) -> T: raise NotImplementedError("The subclass of BaseEstimator must implment `_run` method.") @staticmethod diff --git a/qiskit/primitives/base/base_sampler.py b/qiskit/primitives/base/base_sampler.py index 769343477ffe..6a78c0767fd7 100644 --- a/qiskit/primitives/base/base_sampler.py +++ b/qiskit/primitives/base/base_sampler.py @@ -78,6 +78,7 @@ from abc import abstractmethod from collections.abc import Sequence from copy import copy +from typing import Generic, TypeVar from qiskit.circuit import QuantumCircuit from qiskit.circuit.parametertable import ParameterView @@ -85,8 +86,10 @@ from .base_primitive import BasePrimitive +T = TypeVar("T", bound=Job) -class BaseSampler(BasePrimitive): + +class BaseSampler(BasePrimitive, Generic[T]): """Sampler base class Base class of Sampler that calculates quasi-probabilities of bitstrings from quantum circuits. @@ -112,7 +115,7 @@ def run( circuits: QuantumCircuit | Sequence[QuantumCircuit], parameter_values: Sequence[float] | Sequence[Sequence[float]] | None = None, **run_options, - ) -> Job: + ) -> T: """Run the job of the sampling of bitstrings. Args: @@ -153,7 +156,7 @@ def _run( circuits: tuple[QuantumCircuit, ...], parameter_values: tuple[tuple[float, ...], ...], **run_options, - ) -> Job: + ) -> T: raise NotImplementedError("The subclass of BaseSampler must implment `_run` method.") # TODO: validate measurement gates are present diff --git a/qiskit/primitives/estimator.py b/qiskit/primitives/estimator.py index ba251ce294d7..7333e44332c9 100644 --- a/qiskit/primitives/estimator.py +++ b/qiskit/primitives/estimator.py @@ -36,7 +36,7 @@ ) -class Estimator(BaseEstimator): +class Estimator(BaseEstimator[PrimitiveJob[EstimatorResult]]): """ Reference implementation of :class:`BaseEstimator`. @@ -127,7 +127,7 @@ def _run( observables: tuple[BaseOperator | PauliSumOp, ...], parameter_values: tuple[tuple[float, ...], ...], **run_options, - ) -> PrimitiveJob: + ): circuit_indices = [] for circuit in circuits: key = _circuit_key(circuit) diff --git a/qiskit/primitives/primitive_job.py b/qiskit/primitives/primitive_job.py index f8d295b662b5..95cc9629e0cd 100644 --- a/qiskit/primitives/primitive_job.py +++ b/qiskit/primitives/primitive_job.py @@ -15,11 +15,16 @@ import uuid from concurrent.futures import ThreadPoolExecutor +from typing import Generic, TypeVar from qiskit.providers import JobError, JobStatus, JobV1 +from .base.base_result import BasePrimitiveResult -class PrimitiveJob(JobV1): +T = TypeVar("T", bound=BasePrimitiveResult) + + +class PrimitiveJob(JobV1, Generic[T]): """ PrimitiveJob class for the reference implemetations of Primitives. """ @@ -44,7 +49,7 @@ def submit(self): future = executor.submit(self._function, *self._args, **self._kwargs) self._future = future - def result(self): + def result(self) -> T: """Return the results of the job.""" self._check_submitted() return self._future.result() diff --git a/qiskit/primitives/sampler.py b/qiskit/primitives/sampler.py index 534f1605afe3..82c637b5d2ba 100644 --- a/qiskit/primitives/sampler.py +++ b/qiskit/primitives/sampler.py @@ -35,7 +35,7 @@ ) -class Sampler(BaseSampler): +class Sampler(BaseSampler[PrimitiveJob[SamplerResult]]): """ Sampler class. @@ -119,7 +119,7 @@ def _run( circuits: tuple[QuantumCircuit, ...], parameter_values: tuple[tuple[float, ...], ...], **run_options, - ) -> PrimitiveJob: + ): circuit_indices = [] for circuit in circuits: key = _circuit_key(circuit)