Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generic PrimitiveJob and BaseSampler/BaseEstimator #9920

Merged
merged 4 commits into from
Apr 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions qiskit/primitives/backend_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def _prepare_counts(results: list[Result]):
return counts


class BackendEstimator(BaseEstimator):
class BackendEstimator(BaseEstimator[PrimitiveJob[EstimatorResult]]):
"""Evaluates expectation value using Pauli rotation gates.

The :class:`~.BackendEstimator` class is a generic implementation of the
Expand Down Expand Up @@ -250,7 +250,7 @@ def _run(
observables: tuple[BaseOperator | PauliSumOp, ...],
parameter_values: tuple[tuple[float, ...], ...],
**run_options,
) -> PrimitiveJob:
):
circuit_indices = []
for circuit in circuits:
index = self._circuit_ids.get(_circuit_key(circuit))
Expand Down
4 changes: 2 additions & 2 deletions qiskit/primitives/backend_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from .utils import _circuit_key


class BackendSampler(BaseSampler):
class BackendSampler(BaseSampler[PrimitiveJob[SamplerResult]]):
"""A :class:`~.BaseSampler` implementation that provides an interface for
leveraging the sampler interface from any backend.

Expand Down Expand Up @@ -192,7 +192,7 @@ def _run(
circuits: tuple[QuantumCircuit, ...],
parameter_values: tuple[tuple[float, ...], ...],
**run_options,
) -> PrimitiveJob:
):
circuit_indices = []
for circuit in circuits:
index = self._circuit_ids.get(_circuit_key(circuit))
Expand Down
9 changes: 6 additions & 3 deletions qiskit/primitives/base/base_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
from abc import abstractmethod
from collections.abc import Sequence
from copy import copy
from typing import Generic, TypeVar

from qiskit.circuit import QuantumCircuit
from qiskit.circuit.parametertable import ParameterView
Expand All @@ -94,8 +95,10 @@
from ..utils import init_observable
from .base_primitive import BasePrimitive

T = TypeVar("T", bound=Job)

class BaseEstimator(BasePrimitive):

class BaseEstimator(BasePrimitive, Generic[T]):
"""Estimator base class.

Base class for Estimator that estimates expectation values of quantum circuits and observables.
Expand Down Expand Up @@ -126,7 +129,7 @@ def run(
observables: Sequence[BaseOperator | PauliSumOp | str] | BaseOperator | PauliSumOp | str,
parameter_values: Sequence[Sequence[float]] | Sequence[float] | float | None = None,
**run_options,
) -> Job:
) -> T:
"""Run the job of the estimation of expectation value(s).

``circuits``, ``observables``, and ``parameter_values`` should have the same
Expand Down Expand Up @@ -193,7 +196,7 @@ def _run(
observables: tuple[SparsePauliOp, ...],
parameter_values: tuple[tuple[float, ...], ...],
**run_options,
) -> Job:
) -> T:
raise NotImplementedError("The subclass of BaseEstimator must implment `_run` method.")

@staticmethod
Expand Down
9 changes: 6 additions & 3 deletions qiskit/primitives/base/base_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,18 @@
from abc import abstractmethod
from collections.abc import Sequence
from copy import copy
from typing import Generic, TypeVar

from qiskit.circuit import QuantumCircuit
from qiskit.circuit.parametertable import ParameterView
from qiskit.providers import JobV1 as Job

from .base_primitive import BasePrimitive

T = TypeVar("T", bound=Job)

class BaseSampler(BasePrimitive):

class BaseSampler(BasePrimitive, Generic[T]):
"""Sampler base class

Base class of Sampler that calculates quasi-probabilities of bitstrings from quantum circuits.
Expand All @@ -112,7 +115,7 @@ def run(
circuits: QuantumCircuit | Sequence[QuantumCircuit],
parameter_values: Sequence[float] | Sequence[Sequence[float]] | None = None,
**run_options,
) -> Job:
) -> T:
"""Run the job of the sampling of bitstrings.

Args:
Expand Down Expand Up @@ -153,7 +156,7 @@ def _run(
circuits: tuple[QuantumCircuit, ...],
parameter_values: tuple[tuple[float, ...], ...],
**run_options,
) -> Job:
) -> T:
raise NotImplementedError("The subclass of BaseSampler must implment `_run` method.")

# TODO: validate measurement gates are present
Expand Down
4 changes: 2 additions & 2 deletions qiskit/primitives/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
)


class Estimator(BaseEstimator):
class Estimator(BaseEstimator[PrimitiveJob[EstimatorResult]]):
"""
Reference implementation of :class:`BaseEstimator`.

Expand Down Expand Up @@ -127,7 +127,7 @@ def _run(
observables: tuple[BaseOperator | PauliSumOp, ...],
parameter_values: tuple[tuple[float, ...], ...],
**run_options,
) -> PrimitiveJob:
):
circuit_indices = []
for circuit in circuits:
key = _circuit_key(circuit)
Expand Down
9 changes: 7 additions & 2 deletions qiskit/primitives/primitive_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,16 @@

import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Generic, TypeVar

from qiskit.providers import JobError, JobStatus, JobV1

from .base.base_result import BasePrimitiveResult

class PrimitiveJob(JobV1):
T = TypeVar("T", bound=BasePrimitiveResult)


class PrimitiveJob(JobV1, Generic[T]):
"""
PrimitiveJob class for the reference implemetations of Primitives.
"""
Expand All @@ -44,7 +49,7 @@ def submit(self):
future = executor.submit(self._function, *self._args, **self._kwargs)
self._future = future

def result(self):
def result(self) -> T:
"""Return the results of the job."""
self._check_submitted()
return self._future.result()
Expand Down
4 changes: 2 additions & 2 deletions qiskit/primitives/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
)


class Sampler(BaseSampler):
class Sampler(BaseSampler[PrimitiveJob[SamplerResult]]):
"""
Sampler class.

Expand Down Expand Up @@ -119,7 +119,7 @@ def _run(
circuits: tuple[QuantumCircuit, ...],
parameter_values: tuple[tuple[float, ...], ...],
**run_options,
) -> PrimitiveJob:
):
circuit_indices = []
for circuit in circuits:
key = _circuit_key(circuit)
Expand Down