Fix default batching in variational algorithms (Qiskit#9038)
* Fix default batching in variational algorithms

* fix test

* reduce batching to only SPSA

* fix tests

* Apply suggestions from code review

Co-authored-by: Matthew Treinish <[email protected]>

Co-authored-by: Matthew Treinish <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
3 people authored Nov 1, 2022
1 parent 8d7d300 commit f3a0ddc
Showing 7 changed files with 96 additions and 4 deletions.
11 changes: 11 additions & 0 deletions qiskit/algorithms/eigensolvers/vqd.py
@@ -38,6 +38,9 @@
from ..exceptions import AlgorithmError
from ..observables_evaluator import estimate_observables

# private function as we expect this to be updated in the next release
from ..utils.set_batching import _set_default_batchsize

logger = logging.getLogger(__name__)


@@ -264,10 +267,18 @@ def compute_eigenvalues(
fun=energy_evaluation, x0=initial_point, bounds=bounds
)
else:
# we always want to submit as many estimations per job as possible for minimal
# overhead on the hardware
was_updated = _set_default_batchsize(self.optimizer)

opt_result = self.optimizer.minimize(
fun=energy_evaluation, x0=initial_point, bounds=bounds
)

# reset to original value
if was_updated:
self.optimizer.set_max_evals_grouped(None)

eval_time = time() - start_time

self._update_vqd_result(result, opt_result, eval_time, self.ansatz.copy())
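
The same set-and-reset flow appears below in SamplingVQE and VQE as well. A minimal sketch of the idea, assuming the `_set_default_batchsize` helper added by this commit and the existing `Optimizer.set_max_evals_grouped` method; the standalone function name and the `try`/`finally` guard are illustrative additions, not part of the diff.

```python
# Minimal sketch of the set-and-reset flow around optimizer.minimize();
# the try/finally guard is an illustrative addition, not part of this commit.
from qiskit.algorithms.utils.set_batching import _set_default_batchsize


def minimize_with_default_batching(optimizer, energy_evaluation, initial_point, bounds):
    # Temporarily enable SPSA's default batch size if the user never set one.
    was_updated = _set_default_batchsize(optimizer)
    try:
        result = optimizer.minimize(fun=energy_evaluation, x0=initial_point, bounds=bounds)
    finally:
        # Reset to the original value so the user's optimizer object is left unchanged.
        if was_updated:
            optimizer.set_max_evals_grouped(None)
    return result
```
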
11 changes: 11 additions & 0 deletions qiskit/algorithms/minimum_eigensolvers/sampling_vqe.py
@@ -39,6 +39,9 @@
from ..observables_evaluator import estimate_observables
from ..utils import validate_initial_point, validate_bounds

# private function as we expect this to be updated in the next release
from ..utils.set_batching import _set_default_batchsize


logger = logging.getLogger(__name__)

@@ -208,10 +211,18 @@ def compute_minimum_eigenvalue(
# pylint: disable=not-callable
optimizer_result = self.optimizer(fun=evaluate_energy, x0=initial_point, bounds=bounds)
else:
# we always want to submit as many estimations per job as possible for minimal
# overhead on the hardware
was_updated = _set_default_batchsize(self.optimizer)

optimizer_result = self.optimizer.minimize(
fun=evaluate_energy, x0=initial_point, bounds=bounds
)

# reset to original value
if was_updated:
self.optimizer.set_max_evals_grouped(None)

optimizer_time = time() - start_time

logger.info(
11 changes: 11 additions & 0 deletions qiskit/algorithms/minimum_eigensolvers/vqe.py
@@ -35,6 +35,9 @@
from ..observables_evaluator import estimate_observables
from ..utils import validate_initial_point, validate_bounds

# private function as we expect this to be updated in the next release
from ..utils.set_batching import _set_default_batchsize

logger = logging.getLogger(__name__)


@@ -181,10 +184,18 @@ def compute_minimum_eigenvalue(
fun=evaluate_energy, x0=initial_point, jac=evaluate_gradient, bounds=bounds
)
else:
# we always want to submit as many estimations per job as possible for minimal
# overhead on the hardware
was_updated = _set_default_batchsize(self.optimizer)

optimizer_result = self.optimizer.minimize(
fun=evaluate_energy, x0=initial_point, jac=evaluate_gradient, bounds=bounds
)

# reset to original value
if was_updated:
self.optimizer.set_max_evals_grouped(None)

optimizer_time = time() - start_time

logger.info(
9 changes: 6 additions & 3 deletions qiskit/algorithms/optimizers/optimizer.py
@@ -180,7 +180,7 @@ def __init__(self):
self._bounds_support_level = self.get_support_level()["bounds"]
self._initial_point_support_level = self.get_support_level()["initial_point"]
self._options = {}
self._max_evals_grouped = 1
self._max_evals_grouped = None

@abstractmethod
def get_support_level(self):
@@ -205,7 +205,7 @@ def set_options(self, **kwargs):

# pylint: disable=invalid-name
@staticmethod
def gradient_num_diff(x_center, f, epsilon, max_evals_grouped=1):
def gradient_num_diff(x_center, f, epsilon, max_evals_grouped=None):
"""
We compute the gradient with the numeric differentiation in the parallel way,
around the point x_center.
@@ -214,11 +214,14 @@ def gradient_num_diff(x_center, f, epsilon, max_evals_grouped=1):
x_center (ndarray): point around which we compute the gradient
f (func): the function of which the gradient is to be computed.
epsilon (float): the epsilon used in the numeric differentiation.
max_evals_grouped (int): max evals grouped
max_evals_grouped (int): max evals grouped, defaults to None, which is treated as 1 (i.e. no batching).
Returns:
grad: the gradient computed
"""
if max_evals_grouped is None: # no batching by default
max_evals_grouped = 1

forig = f(*((x_center,)))
grad = []
ei = np.zeros((len(x_center),), float)
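
The hunk above makes `gradient_num_diff` treat a `None` batch size exactly like 1. For context, a hedged sketch of the batched forward-difference idea follows; `f_batch`, which accepts a list of points and returns a list of values, is an assumption made for readability (the library's `f` instead receives a single concatenated array).

```python
import numpy as np


def forward_difference_gradient(f_batch, x_center, epsilon, max_evals_grouped=None):
    """Hedged sketch of batched numeric differentiation; not the library code."""
    if max_evals_grouped is None:  # None is treated like 1, i.e. no batching
        max_evals_grouped = 1

    f0 = f_batch([x_center])[0]
    # one shifted point per parameter
    shifted = [x_center + epsilon * direction for direction in np.eye(len(x_center))]

    values = []
    for start in range(0, len(shifted), max_evals_grouped):
        # each call evaluates at most max_evals_grouped points, i.e. fewer jobs
        values.extend(f_batch(shifted[start:start + max_evals_grouped]))

    return (np.array(values) - f0) / epsilon


# example: gradient of sum(x**2) at (1, 2), evaluated two points per call
grad = forward_difference_gradient(
    lambda points: [float(np.sum(np.asarray(p) ** 2)) for p in points],
    np.array([1.0, 2.0]),
    epsilon=1e-6,
    max_evals_grouped=2,
)
```
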
2 changes: 1 addition & 1 deletion qiskit/algorithms/optimizers/spsa.py
@@ -719,7 +719,7 @@ def _batch_evaluate(function, points, max_evals_grouped, unpack_points=False):
"""

# if the function cannot handle lists of points as input, cover this case immediately
if max_evals_grouped == 1:
if max_evals_grouped is None or max_evals_grouped == 1:
# support functions with multiple arguments where the points are given in a tuple
return [
function(*point) if isinstance(point, tuple) else function(point) for point in points
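
A hedged sketch of the chunking that `_batch_evaluate` performs once a batch size greater than 1 is set; the helper name and the list-based splitting below are illustrative, not the exact library code.

```python
def batch_evaluate(function, points, max_evals_grouped):
    # Illustrative sketch only: with no batch size (None or 1), evaluate point by point.
    if max_evals_grouped is None or max_evals_grouped == 1:
        return [function(point) for point in points]

    # Otherwise hand the function whole chunks of points, one call per chunk.
    results = []
    for start in range(0, len(points), max_evals_grouped):
        results += list(function(points[start:start + max_evals_grouped]))
    return results
```
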
27 changes: 27 additions & 0 deletions qiskit/algorithms/utils/set_batching.py
@@ -0,0 +1,27 @@
# This code is part of Qiskit.
#
# (C) Copyright IBM 2022.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""Set default batch sizes for the optimizers."""

from qiskit.algorithms.optimizers import Optimizer, SPSA


def _set_default_batchsize(optimizer: Optimizer) -> bool:
"""Set the default batchsize, if None is set and return whether it was updated or not."""
if isinstance(optimizer, SPSA):
updated = optimizer._max_evals_grouped is None
if updated:
optimizer.set_max_evals_grouped(50)
else: # we only set a batchsize for SPSA
updated = False

return updated
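
A short, hedged example of how the new helper behaves, assuming the Qiskit Terra version that contains this commit (reading the private `_max_evals_grouped` attribute is for illustration only):

```python
from qiskit.algorithms.optimizers import COBYLA, SPSA
from qiskit.algorithms.utils.set_batching import _set_default_batchsize

spsa = SPSA(maxiter=5)
print(_set_default_batchsize(spsa))      # True: no batch size was set, so 50 is applied
print(spsa._max_evals_grouped)           # 50

spsa.set_max_evals_grouped(10)
print(_set_default_batchsize(spsa))      # False: an explicit user value is never overridden

print(_set_default_batchsize(COBYLA()))  # False: only SPSA receives a default batch size
```
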
29 changes: 29 additions & 0 deletions test/python/algorithms/minimum_eigensolvers/test_vqe.py
@@ -300,6 +300,35 @@ def run_check():
vqe.optimizer = L_BFGS_B()
run_check()

def test_default_batch_evaluation_on_spsa(self):
"""Test the default batching works."""
ansatz = TwoLocal(2, rotation_blocks=["ry", "rz"], entanglement_blocks="cz")

wrapped_estimator = Estimator()
inner_estimator = Estimator()

callcount = {"estimator": 0}

def wrapped_estimator_run(*args, **kwargs):
kwargs["callcount"]["estimator"] += 1
return inner_estimator.run(*args, **kwargs)

wrapped_estimator.run = partial(wrapped_estimator_run, callcount=callcount)

spsa = SPSA(maxiter=5)

vqe = VQE(wrapped_estimator, ansatz, spsa)
_ = vqe.compute_minimum_eigenvalue(Pauli("ZZ"))

# 1 calibration + 5 loss + 1 return loss
expected_estimator_runs = 1 + 5 + 1

with self.subTest(msg="check callcount"):
self.assertEqual(callcount["estimator"], expected_estimator_runs)

with self.subTest(msg="check reset to original max evals grouped"):
self.assertIsNone(spsa._max_evals_grouped)

def test_batch_evaluate_with_qnspsa(self):
"""Test batch evaluating with QNSPSA works."""
ansatz = TwoLocal(2, rotation_blocks=["ry", "rz"], entanglement_blocks="cz")
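
Beyond the regression test above, a hedged usage sketch from the user's side, using the same primitives and operator as the test; the final assertion reads the private `_max_evals_grouped` attribute purely for illustration:

```python
from qiskit.algorithms.minimum_eigensolvers import VQE
from qiskit.algorithms.optimizers import SPSA
from qiskit.circuit.library import TwoLocal
from qiskit.primitives import Estimator
from qiskit.quantum_info import Pauli

ansatz = TwoLocal(2, rotation_blocks=["ry", "rz"], entanglement_blocks="cz")
spsa = SPSA(maxiter=5)
spsa.set_max_evals_grouped(100)  # explicit user choice, not overridden by VQE

vqe = VQE(Estimator(), ansatz, spsa)
result = vqe.compute_minimum_eigenvalue(Pauli("ZZ"))

assert spsa._max_evals_grouped == 100  # the explicit value survives the run
```
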
