From 89535a4fe56b9edb217fd2d5e86013a2766a50d0 Mon Sep 17 00:00:00 2001
From: David Wierichs
Date: Mon, 19 Jun 2023 19:57:41 +0200
Subject: [PATCH 1/5] Support `HardwareHamiltonian` pulses in `stoch_pulse_grad`
 (#4215)

* single out gradient transform checks
* rename stochastic pulse gradient file
* unify gradient_analysis and grad_method_validation
* continue restructure of analysis+validation
* CV
* black
* modularize more
* more modularizing
* black
* tiny [skip ci]
* [skip ci] lint
* remove dummy test
* test fix
* add test file to linting test file
* test fixes, docstrings
* code review
* docstring gradient_analysis_and_grad_method_validation
* move first fun
* code review:move functions
* test regex
* regexs
* move and promote reorder_grads
* tmp
* more tmp
* test cases, contractions
* lint
* docstring
* even more tmp
* cleanup
* black
* tmp
* lint
* move stoch_pulse_gradient.. files back to pulse_gradient...
* move stoch_pulse_gradient.. files back to pulse_gradient...
* lint
* rename
* extend functions and tests
* lint and black
* changelog
* improve
* update example to include non-Pauli word generator
* add jit test with pauli sentence
* tmp
* debugging, docstring, extend test
* review
* optimize for Pauli words
* Apply suggestions from code review

Co-authored-by: Romain Moyard

* test cases code review
* fix parametrization
* drafting
* working prototype
* finish merge; cleanup
* changelog
* comments
* [skip ci]
* raising an error; cleanup [skip ci]
* Apply suggestions from code review

Co-authored-by: Korbinian Kottmann <43949391+Qottmann@users.noreply.github.com>

* change contraction idea
* typo in docs
* tests
* remove prints
* fix test
* test descriptions
* fix merge
* format
* code review; test coverage
* coverage reordering
* fix
* trigger CI
* trigger
* clear caches
* trigger
* trigger

---------

Co-authored-by: Romain Moyard
Co-authored-by: Korbinian Kottmann <43949391+Qottmann@users.noreply.github.com>
Co-authored-by: Korbinian Kottmann
---
 doc/releases/changelog-0.31.0.md            |   4 +-
 pennylane/gradients/pulse_gradient.py       | 259 +++++++--
 tests/gradients/core/test_pulse_gradient.py | 604 +++++++++++++++++++-
 3 files changed, 823 insertions(+), 44 deletions(-)

diff --git a/doc/releases/changelog-0.31.0.md b/doc/releases/changelog-0.31.0.md
index 82a07004fcd..21d1511dea6 100644
--- a/doc/releases/changelog-0.31.0.md
+++ b/doc/releases/changelog-0.31.0.md
@@ -97,8 +97,10 @@
   [(#4216)](https://github.com/PennyLaneAI/pennylane/pull/4216)
 
 * The stochastic parameter-shift gradient transform for pulses, `stoch_pulse_grad`, now
-  supports arbitrary Hermitian generating terms in pulse Hamiltonians.
+  supports arbitrary Hermitian terms in pulse Hamiltonians. It now also supports
+  pulses generated by `HardwareHamiltonian`.
   [(#4132)](https://github.com/PennyLaneAI/pennylane/pull/4132)
+  [(#4215)](https://github.com/PennyLaneAI/pennylane/pull/4215)
 
 * `DiagonalQubitUnitary` now decomposes into `RZ`, `IsingZZ` and `MultiRZ` gates
   instead of a `QubitUnitary` operation with a dense matrix.
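Before the implementation diff, a minimal sketch of what this patch enables: differentiating a `HardwareHamiltonian` pulse with the stochastic parameter-shift rule. It mirrors `test_with_drive_approx` from the test diff below; the drive values, time window, and keyword arguments are illustrative, not canonical.

```python
import jax
import numpy as np
import pennylane as qml

# Constant-amplitude transmon drive; amplitude, phase, and frequency values are illustrative.
H = qml.pulse.transmon_drive(1 / (2 * np.pi), qml.pulse.constant, 0.0, wires=[0])
dev = qml.device("default.qubit.jax", wires=1)

def ansatz(params):
    qml.evolve(H)(params, t=0.1)
    return qml.expval(qml.PauliX(0))

# The extra keyword arguments are forwarded to stoch_pulse_grad, as in the tests below.
cost = qml.QNode(
    ansatz, dev, interface="jax", diff_method=qml.gradients.stoch_pulse_grad,
    num_split_times=7, use_broadcasting=True, sampler_seed=4123,
)
grad = jax.grad(cost)((0.42,))  # stochastic estimate of the pulse-parameter gradient
```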
diff --git a/pennylane/gradients/pulse_gradient.py b/pennylane/gradients/pulse_gradient.py
index c89d710b8c0..4dbc27ae4a5 100644
--- a/pennylane/gradients/pulse_gradient.py
+++ b/pennylane/gradients/pulse_gradient.py
@@ -19,7 +19,7 @@
 import numpy as np
 
 import pennylane as qml
-from pennylane.pulse import ParametrizedEvolution
+from pennylane.pulse import ParametrizedEvolution, HardwareHamiltonian
 
 from .parameter_shift import _make_zero_rep
 from .general_shift_rules import eigvals_to_frequencies, generate_shift_rule
@@ -177,8 +177,8 @@ def _parshift_and_integrate(
         cjacs (tensor_like): classical Jacobian evaluated at the splitting times
         int_prefactor (float): prefactor of the numerical integration, corresponding to the size
             of the time range divided by the number of splitting time samples
-        psr_coeffs (tensor_like): Coefficients of the parameter-shift rule to contract the results
-            with before integrating numerically.
+        psr_coeffs (tensor_like or tuple[tensor_like]): Coefficients of the parameter-shift
+            rule to contract the results with before integrating numerically.
         single_measure (bool): Whether the results contain a single measurement per shot setting
         has_partitioned_shots (bool): Whether the results have a shot vector axis
         use_broadcasting (bool): Whether broadcasting was used in the tapes that returned the
@@ -187,29 +187,79 @@
     Returns:
         tensor_like or tuple[tensor_like] or tuple[tuple[tensor_like]]: Gradient entry
     """
 
-    if use_broadcasting:
+    def _contract(coeffs, res, cjac):
+        """Contract three tensors, the first two like a standard matrix multiplication
+        and the result with the third tensor along the first axes."""
+        return jnp.tensordot(jnp.tensordot(coeffs, res, axes=1), cjac, axes=[[0], [0]])
+
+    if isinstance(psr_coeffs, tuple):
+        num_shifts = [len(c) for c in psr_coeffs]
 
         def _psr_and_contract(res_list, cjacs, int_prefactor):
-            # Stack results and slice away the first and last values, corresponding to the initial
-            # condition and the final value of the time evolution, but not to a splitting time
-            res = qml.math.stack(res_list)[:, 1:-1]
-            # Contract the results with the parameter-shift rule coefficients
-            parshift = qml.math.tensordot(psr_coeffs, res, axes=1)
-            return qml.math.tensordot(parshift, cjacs, axes=[[0], [0]]) * int_prefactor
+            """Execute the parameter-shift rule and contract with classical Jacobians.
+            This function assumes multiple generating terms for the pulse parameter
+            of interest."""
+            res = jnp.stack(res_list)
+            idx = 0
+
+            # Preprocess the results: Reshape, create slices for different generating terms
+            if use_broadcasting:
+                # Slice the results according to the different generating terms. Slice away the
+                # first and last value for each term, which correspond to the initial condition
+                # and the final value of the time evolution, but not to splitting times
+                res = tuple(res[idx : (idx := idx + n), 1:-1] for n in num_shifts)
+            else:
+                shape = jnp.shape(res)
+                num_taus = shape[0] // sum(num_shifts)
+                # Reshape the slices of the results corresponding to different generating terms.
+                # Afterwards the first axis corresponds to the splitting times and the second axis
+                # corresponds to the different shifts of the respective term.
+                # Finally move the shifts-axis to the first position of each term.
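+                # Shape sketch with hypothetical numbers: for num_shifts = [2, 3] and
+                # jnp.shape(res)[0] = 10, we get num_taus = 10 // (2 + 3) = 2. The first
+                # term's 4 results are reshaped to (2, 2) + shape[1:], the second term's
+                # 6 results to (2, 3) + shape[1:]; moveaxis then yields leading shapes
+                # (2, 2) and (3, 2), i.e. the shift axis comes first in each term.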
+                res = tuple(
+                    jnp.moveaxis(
+                        jnp.reshape(
+                            res[idx : (idx := idx + n * num_taus)], (num_taus, n) + shape[1:]
+                        ),
+                        1,
+                        0,
+                    )
+                    for n in num_shifts
+                )
+
+            # Contract the results, parameter-shift rule coefficients and (classical) Jacobians,
+            # and include the rescaling factor from the Monte Carlo integral and from global
+            # prefactors of Pauli word generators.
+            diff_per_term = jnp.array(
+                [_contract(c, r, cjac) for c, r, cjac in zip(psr_coeffs, res, cjacs)]
+            )
+            return qml.math.sum(diff_per_term, axis=0) * int_prefactor
 
     else:
         num_shifts = len(psr_coeffs)
 
         def _psr_and_contract(res_list, cjacs, int_prefactor):
-            res = qml.math.stack(res_list)
-            # Reshape the results such that the first axis corresponds to the splitting times
-            # and the second axis corresponds to different shifts. All other axes are untouched
-            shape = qml.math.shape(res)
-            new_shape = (shape[0] // num_shifts, num_shifts) + shape[1:]
-            res = qml.math.reshape(res, new_shape)
-            # Contract the results with the parameter-shift rule coefficients
-            parshift = qml.math.tensordot(psr_coeffs, res, axes=[[0], [1]])
-            return qml.math.tensordot(parshift, cjacs, axes=[[0], [0]]) * int_prefactor
+            """Execute the parameter-shift rule and contract with classical Jacobians.
+            This function assumes a single generating term for the pulse parameter
+            of interest."""
+            res = jnp.stack(res_list)
+
+            # Preprocess the results: Reshape, create slices for different generating terms
+            if use_broadcasting:
+                # Slice away the first and last values, corresponding to the initial condition
+                # and the final value of the time evolution, but not to splitting times
+                res = res[:, 1:-1]
+            else:
+                # Reshape the results such that the first axis corresponds to the splitting times
+                # and the second axis corresponds to different shifts. All other axes are untouched.
+                # Afterwards move the shifts-axis to the first position.
+                shape = jnp.shape(res)
+                new_shape = (shape[0] // num_shifts, num_shifts) + shape[1:]
+                res = jnp.moveaxis(jnp.reshape(res, new_shape), 1, 0)
+
+            # Contract the results, parameter-shift rule coefficients and (classical) Jacobians,
+            # and include the rescaling factor from the Monte Carlo integral and from global
+            # prefactors of Pauli word generators.
+            return _contract(psr_coeffs, res, cjacs) * int_prefactor
 
     nesting_layers = (not single_measure) + has_partitioned_shots
     if nesting_layers == 1:
@@ -263,13 +313,13 @@ def _stoch_pulse_grad(
         \bra{\psi^{(\pm)}_{j}(\boldsymbol{v}, \tau)} B \ket{\psi^{(\pm)}_{j}(\boldsymbol{v}, \tau)} \\
         \ket{\psi^{(\pm)}_{j}(\boldsymbol{v}, \tau)}
-        &= U_{\boldsymbol{v}}(T, \tau) e^{-i \pm \frac{\pi}{2} H_j}
+        &= U_{\boldsymbol{v}}(T, \tau) e^{-i (\pm \frac{\pi}{4}) H_j}
         U_{\boldsymbol{v}}(\tau, 0)\ket{\psi_0}.
 
     That is, the :math:`j`\ th modified time evolution in these circuits interrupts
     the evolution generated by the pulse Hamiltonian by inserting a rotation gate
     generated by the corresponding Hamiltonian term :math:`H_j` with a rotation angle of
-    :math:`\pm\frac{\pi}{2}`.
+    :math:`\pm\frac{\pi}{4}`.
     See below for a more detailed description. The integral in the first equation above
     is estimated numerically in the stochastic parameter-shift rule.
     For this, it samples
@@ -590,26 +640,49 @@ def ansatz(params):
     return _expval_stoch_pulse_grad(tape, argnum, num_split_times, key, shots, use_broadcasting)
 
 
-def _generate_tapes_and_cjacs(tape, idx, key, num_split_times, use_broadcasting):
+def _generate_tapes_and_cjacs(
+    tape, operation, key, num_split_times, use_broadcasting, par_idx=None
+):
     """Generate the tapes and compute the classical Jacobians for one given generating
     Hamiltonian term of one pulse.
-    """
-    # Obtain the operation into which the indicated parameter feeds, its position in the tape,
-    # and the index of the parameter within the operation
-    op, op_idx, term_idx = tape.get_operation(idx)
-    if not isinstance(op, ParametrizedEvolution):
-        # Only pulses are supported
-        raise ValueError(
-            "stoch_pulse_grad does not support differentiating parameters of "
-            "other operations than pulses."
-        )
+
+    Args:
+        tape (QuantumScript): Tape for which to create the stochastic pulse parameter-shift
+            gradient tapes.
+        operation (tuple[Operation, int, int]): Information about the pulse operation to be
+            shifted. The first entry is the operation itself, the second entry is its position
+            in the ``tape``, and the third entry is the index of the differentiated parameter
+            (and generating term) within the ``HardwareHamiltonian`` of the operation.
+        key (tuple[int]): Randomness key to create splitting times.
+        num_split_times (int): Number of splitting times at which to create shifted tapes for
+            the stochastic shift rule.
+        use_broadcasting (bool): Whether to use broadcasting in the shift rule or not.
+
+    Returns:
+        list[QuantumScript]: Gradient tapes for the indicated operation and Hamiltonian term.
+        list[tensor_like]: Classical Jacobian at the splitting times for the given parameter.
+        float: Prefactor for the Monte Carlo estimate of the integral in the stochastic shift rule.
+        tensor_like: Parameter-shift coefficients for the shift rule of the indicated term.
+    """
+    op, op_idx, term_idx = operation
     coeff, ob = op.H.coeffs_parametrized[term_idx], op.H.ops_parametrized[term_idx]
-    cjac_fn = jax.jacobian(coeff, argnums=0)
+    if par_idx is None:
+        cjac_fn = jax.jacobian(coeff, argnums=0)
+    else:
+        # For `par_idx is not None`, we need to extract the entry of the coefficient
+        # Jacobian that belongs to the parameter of interest. This only happens when
+        # more than one parameter effectively feeds into one coefficient (HardwareHamiltonian)
+
+        def cjac_fn(params, t):
+            return jax.jacobian(coeff, argnums=0)(params, t)[par_idx]
 
     t0, *_, t1 = op.t
     taus = jnp.sort(jax.random.uniform(key, shape=(num_split_times,)) * (t1 - t0) + t0)
-    cjacs = [cjac_fn(op.data[term_idx], tau) for tau in taus]
+    if isinstance(op.H, HardwareHamiltonian):
+        op_data = op.H.reorder_fn(op.data, op.H.coeffs_parametrized)
+    else:
+        op_data = op.data
+    cjacs = [cjac_fn(op_data[term_idx], tau) for tau in taus]
     if use_broadcasting:
         split_evolve_ops, psr_coeffs = _split_evol_ops(op, ob, taus)
         tapes = _split_evol_tape(tape, split_evolve_ops, op_idx)
@@ -618,8 +691,96 @@ def _generate_tapes_and_cjacs(
         for tau in taus:
             split_evolve_ops, psr_coeffs = _split_evol_ops(op, ob, tau)
             tapes.extend(_split_evol_tape(tape, split_evolve_ops, op_idx))
-    avg_prefactor = (t1 - t0) / num_split_times
-    return cjacs, tapes, avg_prefactor, psr_coeffs
+    int_prefactor = (t1 - t0) / num_split_times
+    return tapes, cjacs, int_prefactor, psr_coeffs
+
+
+def _tapes_data_hardware(tape, operation, key, num_split_times, use_broadcasting):
+    """Create tapes and gradient data for a trainable parameter of a HardwareHamiltonian,
+    taking into account its reordering function.
+
+    Args:
+        tape (QuantumScript): Tape for which to create the stochastic pulse parameter-shift
+            gradient tapes.
+        operation (tuple[Operation, int, int]): Information about the pulse operation to be
+            shifted. The first entry is the operation itself, the second entry is its position
+            in the ``tape``, and the third entry is the index of the differentiated parameter
+            within the ``HardwareHamiltonian`` of the operation.
+        key (tuple[int]): Randomness key to create splitting times in ``_generate_tapes_and_cjacs``
+        num_split_times (int): Number of splitting times at which to create shifted tapes for
+            the stochastic shift rule.
+        use_broadcasting (bool): Whether to use broadcasting in the shift rule or not.
+
+    Returns:
+        list[QuantumScript]: Gradient tapes for the indicated operation and Hamiltonian term.
+        tuple: Gradient postprocessing data.
+            See comment below.
+
+    This function analyses the ``reorder_fn`` of the ``HardwareHamiltonian`` of the pulse
+    that is being differentiated. Given a ``term_idx``, the index of the parameter
+    in the Hamiltonian, stochastic parameter shift tapes are created for all terms in the
+    Hamiltonian into which the parameter feeds. While this is a one-to-one relation for
+    standard ``ParametrizedHamiltonian`` objects, the reordering function of
+    the ``HardwareHamiltonian`` requires creating tapes for multiple Hamiltonian terms,
+    and for each term ``_generate_tapes_and_cjacs`` is called.
+
+    The returned gradient data has four entries:
+
+    1. ``int``: Total number of tapes created for all the terms that depend on the indicated
+       parameter.
+    2. ``tuple[tensor_like]``: Classical Jacobians for all terms and splitting times.
+    3. ``float``: Prefactor for the Monte Carlo estimate of the integral in the stochastic
+       shift rule.
+    4. ``tuple[tensor_like]``: Parameter-shift coefficients for all terms.
+
+    The tuple axes in the second and fourth entry correspond to the different terms in the
+    Hamiltonian.
+    """
+    op, op_idx, term_idx = operation
+    # Map a simple enumeration of numbers from HardwareHamiltonian input parameters to
+    # ParametrizedHamiltonian parameters. This is typically a fan-out function.
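+    # Illustration with hypothetical values: for op.num_params = 3, a pure permutation
+    # could map [0, 1, 2] -> [2, 0, 1], a fan-out function could map [0, 1, 2] -> [0, 1, 1]
+    # (parameter 1 feeds two coefficients), and a fan-in function could map
+    # [0, 1, 2] -> [0, [1, 2]] (parameters 1 and 2 jointly feed one coefficient).
+    # Any other output triggers the error defined in _raise below.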
+    fake_params, allowed_outputs = np.arange(op.num_params), set(range(op.num_params))
+    reordered = op.H.reorder_fn(fake_params, op.H.coeffs_parametrized)
+
+    def _raise():
+        raise ValueError(
+            "Only permutations, fan-out or fan-in functions are allowed as reordering functions "
+            "in HardwareHamiltonians treated by stoch_pulse_grad. The reordering function of "
+            f"{op.H} mapped {fake_params} to {reordered}."
+        )
+
+    cjacs, tapes, psr_coeffs = [], [], []
+    for coeff_idx, x in enumerate(reordered):
+        # Find out whether the value term_idx, corresponding to the current parameter of interest,
+        # has been mapped to x (for scalar x) or into x (for 1d x). If so, generate tapes and data
+        # Also check that only allowed outputs have been produced by the reordering function.
+        if not hasattr(x, "__len__"):
+            if x not in allowed_outputs:
+                _raise()
+            if x != term_idx:
+                continue
+            cjac_idx = None
+        else:
+            if not all(_x in list(range(op.num_params)) for _x in x):
+                _raise()
+            if term_idx not in x:
+                continue
+            cjac_idx = np.argwhere([_x == term_idx for _x in x])[0][0]
+
+        _operation = (op, op_idx, coeff_idx)
+        # Overwriting int_prefactor does not matter, it is equal for all parameters in this op,
+        # because it only consists of the duration `op.t[-1]-op.t[0]` and `num_split_times`
+        _tapes, _cjacs, int_prefactor, _psr_coeffs = _generate_tapes_and_cjacs(
+            tape, _operation, key, num_split_times, use_broadcasting, cjac_idx
+        )
+        cjacs.append(qml.math.stack(_cjacs))
+        tapes.extend(_tapes)
+        psr_coeffs.append(_psr_coeffs)
+
+    # The fact that psr_coeffs are a tuple only for hardware Hamiltonian generators will be
+    # used in `_parshift_and_integrate`.
+    data = (len(tapes), tuple(cjacs), int_prefactor, tuple(psr_coeffs))
+    return tapes, data
 
 
 # pylint: disable=too-many-arguments
@@ -637,12 +798,26 @@ def _expval_stoch_pulse_grad(tape, argnum, num_split_times, key, shots, use_broa
             continue
 
         key, _key = jax.random.split(key)
-        cjacs, _tapes, avg_prefactor, psr_coeffs = _generate_tapes_and_cjacs(
-            tape, idx, _key, num_split_times, use_broadcasting
-        )
+        operation = tape.get_operation(idx)
+        op, *_ = operation
+        if not isinstance(op, ParametrizedEvolution):
+            raise ValueError(
+                "stoch_pulse_grad does not support differentiating parameters of "
+                "other operations than pulses."
+            )
+        if isinstance(op.H, HardwareHamiltonian):
+            # Treat HardwareHamiltonians separately because they have a reordering function
+            _tapes, data = _tapes_data_hardware(
+                tape, operation, key, num_split_times, use_broadcasting
+            )
+        else:
+            _tapes, cjacs, int_prefactor, psr_coeffs = _generate_tapes_and_cjacs(
+                tape, operation, _key, num_split_times, use_broadcasting
+            )
+            data = (len(_tapes), qml.math.stack(cjacs), int_prefactor, psr_coeffs)
 
-        gradient_data.append((len(_tapes), qml.math.stack(cjacs), avg_prefactor, psr_coeffs))
         tapes.extend(_tapes)
+        gradient_data.append(data)
 
     num_measurements = len(tape.measurements)
     single_measure = num_measurements == 1
@@ -653,7 +828,7 @@ def _expval_stoch_pulse_grad(tape, argnum, num_split_times, key, shots, use_broa
     def processing_fn(results):
         start = 0
         grads = []
-        for num_tapes, cjacs, avg_prefactor, psr_coeffs in gradient_data:
+        for num_tapes, cjacs, int_prefactor, psr_coeffs in gradient_data:
             if num_tapes == 0:
                 grads.append(None)
                 continue
@@ -664,7 +839,7 @@ def processing_fn(results):
             g = _parshift_and_integrate(
                 res,
                 cjacs,
-                avg_prefactor,
+                int_prefactor,
                 psr_coeffs,
                 single_measure,
                 has_partitioned_shots,
diff --git a/tests/gradients/core/test_pulse_gradient.py b/tests/gradients/core/test_pulse_gradient.py
index 38347846fce..080411d6a59 100644
--- a/tests/gradients/core/test_pulse_gradient.py
+++ b/tests/gradients/core/test_pulse_gradient.py
@@ -23,6 +23,7 @@
 from pennylane.gradients.general_shift_rules import eigvals_to_frequencies, generate_shift_rule
 from pennylane.gradients.pulse_gradient import (
+    _parshift_and_integrate,
     _split_evol_ops,
     _split_evol_tape,
     stoch_pulse_grad,
@@ -264,6 +265,459 @@ def test_with_parametrized_evolution(self):
         assert qml.equal(t.operations[-1], ops[2])
 
 
+@pytest.mark.jax
+class TestParshiftAndIntegrate:
+    """Test the helper routine ``_parshift_and_integrate``. Most tests use uniform
+    return types and parameters, so that we can test against simple tensor contractions."""
+
+    # pylint: disable=too-many-arguments
+
+    @pytest.mark.parametrize("multi_term", [1, 4])
+    @pytest.mark.parametrize("meas_shape", [(), (4,)])
+    @pytest.mark.parametrize("par_shape", [(), (3,), (2, 7)])
+    @pytest.mark.parametrize("num_shifts", [2, 5])
+    @pytest.mark.parametrize("num_split_times", [1, 7])
+    def test_single_measure_single_shots(
+        self, num_split_times, num_shifts, par_shape, multi_term, meas_shape
+    ):
+        """Test that ``_parshift_and_integrate`` works with results for a single measurement
+        per shift and splitting time, and with a single setting of shots. This corresponds to
+        ``single_measure=True and has_partitioned_shots=False``. The test is parametrized with whether
+        or not there are multiple Hamiltonian terms to take into account (and sum their
+        contributions), with the shape of the single measurement and of the parameter, with
+        the number of shifts in the shift rule and with the number of splitting times.
+        """
+        from jax import numpy as jnp
+
+        np.random.seed(3751)
+
+        cjac_shape = (num_split_times,) + par_shape
+        if multi_term > 1:
+            cjacs = tuple(np.random.random(cjac_shape) for _ in range(multi_term))
+            psr_coeffs = tuple(np.random.random(num_shifts) for _ in range(multi_term))
+        else:
+            cjacs = np.random.random(cjac_shape)
+            psr_coeffs = np.random.random(num_shifts)
+
+        results_shape = (num_split_times * num_shifts * multi_term,) + meas_shape
+        new_results_shape = (
+            multi_term,
+            num_split_times,
+            num_shifts,
+        ) + meas_shape
+        results = np.random.random(results_shape)
+
+        prefactor = 0.3214
+
+        res = _parshift_and_integrate(
+            results,
+            cjacs,
+            prefactor,
+            psr_coeffs,
+            single_measure=True,
+            has_partitioned_shots=False,
+            use_broadcasting=False,
+        )
+
+        assert isinstance(res, jnp.ndarray)
+        assert res.shape == meas_shape + par_shape
+
+        _results = np.reshape(results, new_results_shape)
+        _cjacs = np.stack(cjacs).reshape((multi_term,) + cjac_shape)
+        _psr_coeffs = np.stack(psr_coeffs).reshape((multi_term, num_shifts))
+        meas_letter = "" if meas_shape == () else "a"
+        contraction = f"ms,mts{meas_letter},mt...->{meas_letter}..."
+        expected = np.einsum(contraction, _psr_coeffs, _results, _cjacs)
+        assert np.allclose(res, expected * prefactor)
+
+    @pytest.mark.parametrize("multi_term", [1, 4])
+    @pytest.mark.parametrize("meas_shape", [(), (4,)])
+    @pytest.mark.parametrize("par_shape", [(), (3,), (2, 7)])
+    @pytest.mark.parametrize("num_shifts", [2, 5])
+    @pytest.mark.parametrize("num_split_times", [1, 7])
+    def test_single_measure_single_shots_broadcast(
+        self, num_split_times, num_shifts, par_shape, multi_term, meas_shape
+    ):
+        """Test that ``_parshift_and_integrate`` works with results for a single measurement
+        per shift and splitting time, and with a single setting of shots. This corresponds to
+        ``single_measure=True and has_partitioned_shots=False``. The test is parametrized with whether
+        or not there are multiple Hamiltonian terms to take into account (and sum their
+        contributions), with the shape of the single measurement and of the parameter, with
+        the number of shifts in the shift rule and with the number of splitting times.
+        This is the variant of the previous test that uses broadcasting.
+        """
+        from jax import numpy as jnp
+
+        np.random.seed(3751)
+
+        cjac_shape = (num_split_times,) + par_shape
+        if multi_term > 1:
+            cjacs = tuple(np.random.random(cjac_shape) for _ in range(multi_term))
+            psr_coeffs = tuple(np.random.random(num_shifts) for _ in range(multi_term))
+        else:
+            cjacs = np.random.random(cjac_shape)
+            psr_coeffs = np.random.random(num_shifts)
+
+        results_shape = (num_shifts * multi_term, (num_split_times + 2)) + meas_shape
+        new_results_shape = (
+            multi_term,
+            num_shifts,
+            num_split_times + 2,
+        ) + meas_shape
+        results = np.random.random(results_shape)
+
+        prefactor = 0.3214
+
+        res = _parshift_and_integrate(
+            results,
+            cjacs,
+            prefactor,
+            psr_coeffs,
+            single_measure=True,
+            has_partitioned_shots=False,
+            use_broadcasting=True,
+        )
+
+        assert isinstance(res, jnp.ndarray)
+        assert res.shape == meas_shape + par_shape
+
+        _results = np.reshape(results, new_results_shape)
+        _cjacs = np.stack(cjacs).reshape((multi_term,) + cjac_shape)
+        _psr_coeffs = np.stack(psr_coeffs).reshape((multi_term, num_shifts))
+        meas_letter = "" if meas_shape == () else "a"
+        # Slice away excess results
+        _results = _results[:, :, 1:-1]
+        # With broadcasting, the axes of different shifts and splitting times are
+        # switched for the results tensor, compared to without broadcasting.
+        contraction = f"ms,mst{meas_letter},mt...->{meas_letter}..."
+        expected = np.einsum(contraction, _psr_coeffs, _results, _cjacs)
+        assert np.allclose(res, expected * prefactor)
+
+    @pytest.mark.parametrize("multi_term", [1, 4])
+    @pytest.mark.parametrize("meas_shape", [(), (4,)])
+    @pytest.mark.parametrize("par_shape", [(), (3,), (2, 2)])
+    @pytest.mark.parametrize("num_shifts", [2, 5])
+    @pytest.mark.parametrize("num_split_times", [1, 3])
+    def test_multi_measure_or_multi_shots(
+        self, num_split_times, num_shifts, par_shape, multi_term, meas_shape
+    ):
+        """Test that ``_parshift_and_integrate`` works with results for multiple measurements
+        per shift and splitting time and with a single setting of shots, or alternatively with
+        a single measurement but multiple shot settings. This corresponds to
+        ``single_measure=False and has_partitioned_shots=False`` or
+        ``single_measure=True and has_partitioned_shots=True``. The test is parametrized with whether
+        or not there are multiple Hamiltonian terms to take into account (and sum their
+        contributions), with the shape of the single measurement and of the parameter, with
+        the number of shifts in the shift rule and with the number of splitting times.
+        """
+        from jax import numpy as jnp
+
+        np.random.seed(3751)
+
+        num_meas_or_shots = 5
+
+        cjac_shape = (num_split_times,) + par_shape
+        if multi_term > 1:
+            cjacs = tuple(np.random.random(cjac_shape) for _ in range(multi_term))
+            psr_coeffs = tuple(np.random.random(num_shifts) for _ in range(multi_term))
+        else:
+            cjacs = np.random.random(cjac_shape)
+            psr_coeffs = np.random.random(num_shifts)
+
+        results_shape = (
+            num_split_times * num_shifts * multi_term,
+            num_meas_or_shots,
+        ) + meas_shape
+        new_results_shape = (
+            multi_term,
+            num_split_times,
+            num_shifts,
+            num_meas_or_shots,
+        ) + meas_shape
+        results = np.random.random(results_shape)
+
+        prefactor = 0.3214
+
+        res0, res1 = (
+            _parshift_and_integrate(
+                results,
+                cjacs,
+                prefactor,
+                psr_coeffs,
+                single_measure=_bool,
+                has_partitioned_shots=_bool,
+                use_broadcasting=False,
+            )
+            for _bool in [False, True]
+        )
+
+        _results = np.reshape(results, new_results_shape)
+        _cjacs = np.stack(cjacs).reshape((multi_term,) + cjac_shape)
+        _psr_coeffs = np.stack(psr_coeffs).reshape((multi_term, num_shifts))
+        meas_letter = "" if meas_shape == () else "a"
+        contraction = f"ms,mtsn{meas_letter},mt...->n{meas_letter}..."
+        expected = np.einsum(contraction, _psr_coeffs, _results, _cjacs)
+
+        for res in [res0, res1]:
+            assert isinstance(res, tuple)
+            assert len(res) == num_meas_or_shots
+            assert all(isinstance(r, jnp.ndarray) for r in res)
+            assert all(r.shape == meas_shape + par_shape for r in res)
+
+            assert np.allclose(np.stack(res), expected * prefactor)
+
+    @pytest.mark.parametrize("multi_term", [1, 4])
+    @pytest.mark.parametrize("meas_shape", [(), (4,)])
+    @pytest.mark.parametrize("par_shape", [(), (3,), (2, 2)])
+    @pytest.mark.parametrize("num_shifts", [2, 5])
+    @pytest.mark.parametrize("num_split_times", [1, 3])
+    def test_multi_measure_or_multi_shots_broadcast(
+        self, num_split_times, num_shifts, par_shape, multi_term, meas_shape
+    ):
+        """Test that ``_parshift_and_integrate`` works with results for multiple measurements
+        per shift and splitting time and with a single setting of shots, or alternatively with
+        a single measurement but multiple shot settings. This corresponds to
+        ``single_measure=False and has_partitioned_shots=False`` or
+        ``single_measure=True and has_partitioned_shots=True``. The test is parametrized with whether
+        or not there are multiple Hamiltonian terms to take into account (and sum their
+        contributions), with the shape of the single measurement and of the parameter, with
+        the number of shifts in the shift rule and with the number of splitting times.
+        This is the variant of the previous test that uses broadcasting.
+        """
+        from jax import numpy as jnp
+
+        np.random.seed(3751)
+
+        num_meas_or_shots = 5
+
+        cjac_shape = (num_split_times,) + par_shape
+        if multi_term > 1:
+            cjacs = tuple(np.random.random(cjac_shape) for _ in range(multi_term))
+            psr_coeffs = tuple(np.random.random(num_shifts) for _ in range(multi_term))
+        else:
+            cjacs = np.random.random(cjac_shape)
+            psr_coeffs = np.random.random(num_shifts)
+
+        results_shape = (
+            num_shifts * multi_term,
+            num_meas_or_shots,
+            (num_split_times + 2),
+        ) + meas_shape
+        new_results_shape = (
+            multi_term,
+            num_shifts,
+            num_meas_or_shots,
+            num_split_times + 2,
+        ) + meas_shape
+        results = np.random.random(results_shape)
+
+        prefactor = 0.3214
+
+        res0, res1 = (
+            _parshift_and_integrate(
+                results,
+                cjacs,
+                prefactor,
+                psr_coeffs,
+                single_measure=_bool,
+                has_partitioned_shots=_bool,
+                use_broadcasting=True,
+            )
+            for _bool in [False, True]
+        )
+
+        _results = np.reshape(results, new_results_shape)
+        _cjacs = np.stack(cjacs).reshape((multi_term,) + cjac_shape)
+        _psr_coeffs = np.stack(psr_coeffs).reshape((multi_term, num_shifts))
+        meas_letter = "" if meas_shape == () else "a"
+        # Slice away excess results
+        _results = _results[:, :, :, 1:-1]
+        # With broadcasting, the axes of different shifts and splitting times are
+        # switched for the results tensor, compared to without broadcasting.
+        contraction = f"ms,msnt{meas_letter},mt...->n{meas_letter}..."
+        expected = np.einsum(contraction, _psr_coeffs, _results, _cjacs)
+
+        for res in [res0, res1]:
+            assert isinstance(res, tuple)
+            assert len(res) == num_meas_or_shots
+            assert all(isinstance(r, jnp.ndarray) for r in res)
+            assert all(r.shape == meas_shape + par_shape for r in res)
+
+            assert np.allclose(np.stack(res), expected * prefactor)
+
+    @pytest.mark.parametrize("multi_term", [1, 4])
+    @pytest.mark.parametrize("meas_shape", [(), (4,)])
+    @pytest.mark.parametrize("par_shape", [(), (3,), (2, 2)])
+    @pytest.mark.parametrize("num_shifts", [2, 5])
+    @pytest.mark.parametrize("num_split_times", [1, 3])
+    def test_multi_measure_multi_shots(
+        self, num_split_times, num_shifts, par_shape, multi_term, meas_shape
+    ):
+        """Test that ``_parshift_and_integrate`` works with results for multiple measurements
+        per shift and splitting time and with multiple shot settings. This corresponds to
+        ``single_measure=False and has_partitioned_shots=True``. The test is parametrized with whether
+        or not there are multiple Hamiltonian terms to take into account (and sum their
+        contributions), with the shape of the single measurement and of the parameter, with
+        the number of shifts in the shift rule and with the number of splitting times.
+        """
+        from jax import numpy as jnp
+
+        np.random.seed(3751)
+
+        num_shots = 3
+        num_meas = 5
+
+        cjac_shape = (num_split_times,) + par_shape
+        if multi_term > 1:
+            cjacs = tuple(np.random.random(cjac_shape) for _ in range(multi_term))
+            psr_coeffs = tuple(np.random.random(num_shifts) for _ in range(multi_term))
+        else:
+            cjacs = np.random.random(cjac_shape)
+            psr_coeffs = np.random.random(num_shifts)
+
+        results_shape = (
+            num_split_times * num_shifts * multi_term,
+            num_shots,
+            num_meas,
+        ) + meas_shape
+        new_results_shape = (
+            multi_term,
+            num_split_times,
+            num_shifts,
+            num_shots,
+            num_meas,
+        ) + meas_shape
+        results = np.random.random(results_shape)
+
+        prefactor = 0.3214
+
+        res = _parshift_and_integrate(
+            results,
+            cjacs,
+            prefactor,
+            psr_coeffs,
+            single_measure=False,
+            has_partitioned_shots=True,
+            use_broadcasting=False,
+        )
+
+        assert isinstance(res, tuple)
+        assert len(res) == num_shots
+        for r in res:
+            assert isinstance(r, tuple)
+            assert len(r) == num_meas
+            assert all(isinstance(_r, jnp.ndarray) for _r in r)
+            assert all(_r.shape == meas_shape + par_shape for _r in r)
+
+        _results = np.reshape(results, new_results_shape)
+        _cjacs = np.stack(cjacs).reshape((multi_term,) + cjac_shape)
+        _psr_coeffs = np.stack(psr_coeffs).reshape((multi_term, num_shifts))
+        meas_letter = "" if meas_shape == () else "a"
+        contraction = f"ms,mtsNn{meas_letter},mt...->Nn{meas_letter}..."
+        expected = np.einsum(contraction, _psr_coeffs, _results, _cjacs)
+        assert np.allclose(np.stack(res), expected * prefactor)
+
+    # TODO: Once #2690 is resolved and the corresponding error is removed,
+    # unskip the following test
+    @pytest.mark.skip("Broadcasting, shot vector and multi-measurement not supported.")
+    @pytest.mark.parametrize("multi_term", [1, 4])
+    @pytest.mark.parametrize("meas_shape", [(), (4,)])
+    @pytest.mark.parametrize("par_shape", [(), (3,), (2, 2)])
+    @pytest.mark.parametrize("num_shifts", [2, 5])
+    @pytest.mark.parametrize("num_split_times", [1, 3])
+    def test_multi_measure_multi_shots_broadcast(
+        self, num_split_times, num_shifts, par_shape, multi_term, meas_shape
+    ):
+        """Test that ``_parshift_and_integrate`` works with results for multiple measurements
+        per shift and splitting time and with multiple shot settings. This corresponds to
+        ``single_measure=False and has_partitioned_shots=True``. The test is parametrized with whether
+        or not there are multiple Hamiltonian terms to take into account (and sum their
+        contributions), with the shape of the single measurement and of the parameter, with
+        the number of shifts in the shift rule and with the number of splitting times.
+        This is the variant of the previous test that uses broadcasting.
+        """
+        from jax import numpy as jnp
+
+        np.random.seed(3751)
+
+        num_shots = 3
+        num_meas = 5
+
+        cjac_shape = (num_split_times,) + par_shape
+        if multi_term > 1:
+            cjacs = tuple(np.random.random(cjac_shape) for _ in range(multi_term))
+            psr_coeffs = tuple(np.random.random(num_shifts) for _ in range(multi_term))
+        else:
+            cjacs = np.random.random(cjac_shape)
+            psr_coeffs = np.random.random(num_shifts)
+
+        results_shape = (
+            num_shifts * multi_term,
+            num_shots,
+            num_meas,
+            (num_split_times + 2),
+        ) + meas_shape
+        new_results_shape = (
+            multi_term,
+            num_shifts,
+            num_shots,
+            num_meas,
+            num_split_times + 2,
+        ) + meas_shape
+        results = np.random.random(results_shape)
+
+        prefactor = 0.3214
+
+        res = _parshift_and_integrate(
+            results,
+            cjacs,
+            prefactor,
+            psr_coeffs,
+            single_measure=False,
+            has_partitioned_shots=True,
+            use_broadcasting=True,
+        )
+
+        assert isinstance(res, tuple)
+        assert len(res) == num_shots
+        for r in res:
+            assert isinstance(r, tuple)
+            assert len(r) == num_meas
+            assert all(isinstance(_r, jnp.ndarray) for _r in r)
+            assert all(_r.shape == meas_shape + par_shape for _r in r)
+
+        _results = np.reshape(results, new_results_shape)
+        _cjacs = np.stack(cjacs).reshape((multi_term,) + cjac_shape)
+        _psr_coeffs = np.stack(psr_coeffs).reshape((multi_term, num_shifts))
+        meas_letter = "" if meas_shape == () else "a"
+        # Slice away excess results
+        _results = _results[:, :, :, :, 1:-1]
+        # With broadcasting, the axes of different shifts and splitting times are
+        # switched for the results tensor, compared to without broadcasting.
+        contraction = f"ms,msNnt{meas_letter},mt...->Nn{meas_letter}..."
+        expected = np.einsum(contraction, _psr_coeffs, _results, _cjacs)
+        assert np.allclose(np.stack(res), expected * prefactor)
+
+    # TODO: Once #2690 is resolved and the corresponding error is removed,
+    # remove the following test
+    def test_raises_multi_measure_multi_shots_broadcasting(self):
+        """Test that an error is raised if multiple measurements, a shot vector and broadcasting
+        all are used simultaneously."""
+
+        _match = "Broadcasting, multiple measurements and shot vectors are currently"
+        with pytest.raises(NotImplementedError, match=_match):
+            # Dummy input values that are barely used before raising the error.
+            _parshift_and_integrate(
+                [],
+                [],
+                [],
+                [],
+                single_measure=False,
+                has_partitioned_shots=True,
+                use_broadcasting=True,
+            )
+
+
 @pytest.mark.jax
 class TestStochPulseGradErrors:
     """Test errors raised by stoch_pulse_grad."""
@@ -339,6 +793,26 @@ def test_raises_use_broadcasting_with_broadcasted_tape(self):
         with pytest.raises(ValueError, match="Broadcasting is not supported for tapes that"):
             stoch_pulse_grad(tape, use_broadcasting=True)
 
+    @pytest.mark.parametrize(
+        "reorder_fn",
+        [
+            lambda x, _: [x[0] + 10, x[0] - 2],
+            lambda x, _: [[x[0], x[0] + 10], [x[0] - 2, x[0]]],
+        ],
+    )
+    def test_raises_for_invalid_reorder_fn(self, reorder_fn):
+        """Test that an error is raised for an invalid reordering function of
+        a HardwareHamiltonian."""
+
+        H = qml.pulse.transmon_drive(qml.pulse.constant, 0.0, 0.0, wires=[0])
+        H.reorder_fn = reorder_fn
+        ops = [qml.evolve(H)([0.152], 0.3)]
+        tape = qml.tape.QuantumScript(ops, measurements=[qml.expval(qml.PauliZ(0))])
+        tape.trainable_params = [0]
+        _match = "Only permutations, fan-out or fan-in functions are allowed as reordering"
+        with pytest.raises(ValueError, match=_match):
+            stoch_pulse_grad(tape)
+
 
 @pytest.mark.jax
 class TestStochPulseGrad:
@@ -377,6 +851,7 @@ def sine(p, t):
     def test_all_zero_grads(self, ops, arg, exp_shapes):
         """Test that a zero gradient is returned when all trainable parameters are
         identified to have zero gradient in advance."""
+        import jax
         from jax import numpy as jnp
 
         arg = None if arg is None else jnp.array(arg)
@@ -393,6 +868,7 @@ def test_all_zero_grads(self, ops, arg, exp_shapes):
             assert all(qml.math.allclose(_r, np.zeros(_sh)) for _r, _sh in zip(r, exp_shape))
         else:
             assert qml.math.allclose(r, np.zeros(exp_shape))
+        jax.clear_caches()
 
     def test_some_zero_grads(self):
         """Test that a zero gradient is returned for trainable parameters that are
@@ -415,6 +891,7 @@ def test_some_zero_grads(self):
         assert isinstance(res, tuple) and len(res) == 2
         assert qml.math.allclose(res[0][0], np.zeros(5))
         assert qml.math.allclose(res[1][0], np.zeros((2, 5)))
+        jax.clear_caches()
 
     @pytest.mark.parametrize("num_split_times", [1, 3])
     @pytest.mark.parametrize("t", [2.0, 3, (0.5, 0.6), (0.1, 0.9, 1.2)])
@@ -441,6 +918,7 @@ def test_constant_ry(self, num_split_times, t):
         res = fn(qml.execute(tapes, dev, None))
         assert qml.math.isclose(res, -2 * jnp.sin(2 * p) * delta_t)
+        jax.clear_caches()
 
     @pytest.mark.parametrize("num_split_times", [1, 3])
     @pytest.mark.parametrize("t", [2.0, 3, (0.5, 0.6), (0.1, 0.9, 1.2)])
@@ -471,11 +949,13 @@ def test_constant_ry_rescaled(self, num_split_times, t):
         res = fn(qml.execute(tapes, dev, None))
         assert qml.math.isclose(res, -2 * jnp.sin(2 * p) * delta_t * prefactor)
+        jax.clear_caches()
 
     @pytest.mark.parametrize("t", [0.02, (0.5, 0.6)])
     def test_sin_envelope_rz_expval(self, t):
         """Test that the derivative of a pulse with a sine wave envelope
         is computed correctly when returning an expectation value."""
+        import jax
         import jax.numpy as jnp
 
         T = t if isinstance(t, tuple) else (0, t)
@@ -508,11 +988,13 @@ def test_sin_envelope_rz_expval(self, t):
         exp_grad = -2 * jnp.sin(2 * theta) * theta_jac
         # classical Jacobian is being estimated with the Monte Carlo sampling -> coarse tolerance
         assert qml.math.allclose(res, exp_grad, atol=0.2)
+        jax.clear_caches()
 
     @pytest.mark.parametrize("t", [0.02, (0.5, 0.6)])
     def test_sin_envelope_rx_probs(self, t):
         """Test that the derivative of a pulse with a sine wave envelope
         is computed correctly when returning probabilities."""
+        import jax
        import jax.numpy as jnp
 
         T = t if isinstance(t, tuple) else (0, t)
@@ -547,11 +1029,13 @@ def test_sin_envelope_rx_probs(self, t):
         exp_jac = jnp.tensordot(probs_jac, theta_jac, axes=0)
         # classical Jacobian is being estimated with the Monte Carlo sampling -> coarse tolerance
         assert qml.math.allclose(jac, exp_jac, atol=0.2)
+        jax.clear_caches()
 
     @pytest.mark.parametrize("t", [0.02, (0.5, 0.6)])
     def test_sin_envelope_rx_expval_probs(self, t):
         """Test that the derivative of a pulse with a sine wave envelope
         is computed correctly when returning an expectation value and probabilities."""
+        import jax
         import jax.numpy as jnp
 
         T = t if isinstance(t, tuple) else (0, t)
@@ -590,11 +1074,13 @@ def test_sin_envelope_rx_expval_probs(self, t):
         # classical Jacobian is being estimated with the Monte Carlo sampling -> coarse tolerance
         for j, e in zip(jac, exp_jac):
             assert qml.math.allclose(j, e, atol=0.2)
+        jax.clear_caches()
 
     @pytest.mark.parametrize("t", [0.02, (0.5, 0.6)])
     def test_pwc_envelope_rx(self, t):
         """Test that the derivative of a pulse generated by a piecewise constant Hamiltonian
         is computed correctly."""
+        import jax
         import jax.numpy as jnp
 
         T = t if isinstance(t, tuple) else (0, t)
@@ -618,6 +1104,7 @@ def test_pwc_envelope_rx(self, t):
         assert qml.math.allclose(
             res, -2 * jnp.sin(2 * p) * (T[1] - T[0]) / len(params[0]), atol=0.01
         )
+        jax.clear_caches()
 
     @pytest.mark.parametrize("t", [2.0, 3, (0.5, 0.6)])
     def test_constant_commuting(self, t):
@@ -648,6 +1135,7 @@ def test_constant_commuting(self, t):
             -2 * jnp.sin(2 * p[1]) * jnp.cos(2 * p[0]) * (T[1] - T[0]),
         ]
         assert qml.math.allclose(res, exp_grad)
+        jax.clear_caches()
 
     def test_advanced_pulse(self):
         """Test the derivative of a more complex pulse."""
@@ -685,6 +1173,7 @@ def test_advanced_pulse(self):
         res = fn(qml.execute(tapes, dev, None))
         exp_grad = jax.grad(qnode)(params)
         assert all(qml.math.allclose(r, e, rtol=0.4) for r, e in zip(res, exp_grad))
+        jax.clear_caches()
 
     def test_randomness(self):
         """Test that the derivative of a pulse is exactly the same when reusing a seed and
@@ -723,6 +1212,7 @@ def test_randomness(self):
 
         assert res_a_0 == res_a_1
         assert not res_a_0 == res_b
+        jax.clear_caches()
 
     def test_two_pulses(self):
         """Test that the derivatives of two pulses in a circuit are computed correctly."""
@@ -754,6 +1244,7 @@ def qnode(params_0, params_1):
         exp_grad = jax.grad(qnode, argnums=(0, 1))(params_0, params_1)
         exp_grad = exp_grad[0] + exp_grad[1]
         assert all(qml.math.allclose(r, e, rtol=0.4) for r, e in zip(res, exp_grad))
+        jax.clear_caches()
 
     @pytest.mark.parametrize(
         "generator, exp_num_tapes, prefactor",
@@ -791,6 +1282,7 @@ def fun(params):
         assert qml.math.isclose(res, -2 * jnp.sin(2 * p) * (T[1] - T[0]) * prefactor)
         res_jit = jax.jit(fun)(params)
         assert qml.math.isclose(res, res_jit)
+        jax.clear_caches()
 
     @pytest.mark.parametrize("shots", [None, 100])
     def test_shots_attribute(self, shots):
@@ -879,6 +1371,7 @@ def circuit(params):
         p = params[0] * T
         exp_grad = -2 * jnp.sin(2 * p) * T
         assert qml.math.allclose(grad, exp_grad, atol=tol, rtol=0.0)
+        jax.clear_caches()
 
     @pytest.mark.parametrize("shots, tol", [(None, 1e-4), (100, 0.1), ([100, 99], 0.1)])
     @pytest.mark.parametrize("num_split_times", [1, 2])
@@ -909,6 +1402,7 @@ def circuit(params):
         p_y = params[1][0] * T_y
         exp_grad = [[-2 * jnp.sin(2 * (p_x + p_y)) * T_x], [-2 * jnp.sin(2 * (p_x + p_y)) * T_y]]
         assert qml.math.allclose(grad, exp_grad, atol=tol, rtol=0.0)
+        jax.clear_caches()
 
     @pytest.mark.parametrize("shots, tol", [(None, 1e-4), (100, 0.1), ([100, 99], 0.1)])
     @pytest.mark.parametrize("num_split_times", [1, 2])
@@ -935,6 +1429,7 @@ def circuit(params):
         p = params[0] * T
         exp_jac = jnp.array([-1, 1]) * jnp.sin(2 * p) * T
         assert qml.math.allclose(jac, exp_jac, atol=tol, rtol=0.0)
+        jax.clear_caches()
 
     @pytest.mark.parametrize("shots, tol", [(None, 1e-4), (100, 0.1), ([100, 100], 0.1)])
     @pytest.mark.parametrize("num_split_times", [1, 2])
@@ -967,6 +1462,7 @@ def circuit(params):
         else:
             for j, e in zip(jac, exp_jac):
                 assert qml.math.allclose(j[0], e, atol=tol, rtol=0.0)
+        jax.clear_caches()
 
     @pytest.mark.xfail
     @pytest.mark.parametrize("num_split_times", [1, 2])
@@ -993,6 +1489,7 @@ def circuit(params, T=None):
         exp_grad = -2 * jnp.sin(2 * p) * T
         jit_grad = jax.jit(jax.grad(circuit))(params, T=T)
         assert qml.math.isclose(jit_grad, exp_grad)
+        jax.clear_caches()
 
     @pytest.mark.slow
     def test_advanced_qnode(self):
@@ -1033,8 +1530,9 @@ def ansatz(params):
         assert all(
             qml.math.allclose(r, e, rtol=0.4) for r, e in zip(grad_pulse_grad, grad_backprop)
         )
+        jax.clear_caches()
 
-    def test_multi_return_broadcasting_shot_vector_raises(self):
+    def test_multi_return_broadcasting_multi_shots_raises(self):
         """Test that a simple qnode that returns an expectation value and probabilities
         cannot be differentiated with stoch_pulse_grad with use_broadcasting."""
         import jax
@@ -1060,6 +1558,7 @@ def circuit(params):
         params = [jnp.array(0.4)]
         with pytest.raises(NotImplementedError, match="Broadcasting, multiple measurements and"):
             jax.jacobian(circuit)(params)
+        jax.clear_caches()
 
     # TODO: delete error test above and uncomment the following test case once #2690 is resolved.
     @pytest.mark.parametrize("shots, tol", [(None, 1e-4), (100, 0.1)])  # , ([100, 100], 0.1)])
@@ -1097,6 +1596,7 @@ def circuit(params):
         else:
             for j, e in zip(jac, exp_jac):
                 assert qml.math.allclose(j[0], e, atol=tol, rtol=0.0)
+        jax.clear_caches()
 
     @pytest.mark.parametrize("num_split_times", [1, 2])
     def test_broadcasting_coincides_with_nonbroadcasting(self, num_split_times):
@@ -1141,6 +1641,108 @@ def ansatz(params):
         jac_no_bc = jax.jacobian(circuit_no_bc)(params)
         for j0, j1 in zip(jac_bc, jac_no_bc):
             assert qml.math.allclose(j0, j1)
+        jax.clear_caches()
+
+    def test_with_drive_exact(self):
+        """Test that a HardwareHamiltonian only containing a drive is differentiated correctly
+        for a constant amplitude and zero frequency and phase."""
+        import jax
+
+        timespan = 0.4
+
+        H = qml.pulse.transmon_drive(qml.pulse.constant, 0.0, 0.0, wires=[0])
+        atol = 1e-5
+        dev = qml.device("default.qubit.jax", wires=1)
+
+        def ansatz(params):
+            qml.evolve(H, atol=atol)(params, t=timespan)
+            return qml.expval(qml.PauliZ(0))
+
+        cost = qml.QNode(ansatz, dev, interface="jax", diff_method=qml.gradients.stoch_pulse_grad)
+        cost_jax = qml.QNode(ansatz, dev, interface="jax")
+        params = (0.42,)
+
+        gradfn = jax.grad(cost)
+        res = gradfn(params)
+        exact = jax.grad(cost_jax)(params)
+        assert qml.math.allclose(res, exact, atol=6e-5)
+        jax.clear_caches()
+
+    def test_with_drive_approx(self):
+        """Test that a HardwareHamiltonian only containing a drive is differentiated
+        approximately correctly for a constant phase and zero frequency."""
+        import jax
+
+        timespan = 0.1
+
+        H = qml.pulse.transmon_drive(1 / (2 * np.pi), qml.pulse.constant, 0.0, wires=[0])
+        atol = 1e-5
+        dev = qml.device("default.qubit.jax", wires=1)
+
+        def ansatz(params):
+            qml.evolve(H, atol=atol)(params, t=timespan)
+            return qml.expval(qml.PauliX(0))
+
+        cost = qml.QNode(
+            ansatz,
+            dev,
+            interface="jax",
+            diff_method=qml.gradients.stoch_pulse_grad,
+            num_split_times=7,
+            use_broadcasting=True,
+            sampler_seed=4123,
+        )
+        cost_jax = qml.QNode(ansatz, dev, interface="jax")
+        params = (0.42,)
+
+        gradfn = jax.grad(cost)
+        res = gradfn(params)
+        exact = jax.grad(cost_jax)(params)
+        assert qml.math.allclose(res, exact, atol=1e-3)
+        jax.clear_caches()
+
+    @pytest.mark.parametrize("num_params", [1, 2])
+    def test_with_two_drives(self, num_params):
+        """Test that a HardwareHamiltonian only containing two drives
+        is differentiated approximately correctly. The two cases
+        of the parametrization test the cases where reordered parameters
+        are returned as inner lists and where they remain scalars."""
+        import jax
+
+        timespan = 0.1
+
+        if num_params == 1:
+            amps = [1 / 5, 1 / 6]
+            params = (0.42, -0.91)
+        else:
+            amps = [qml.pulse.constant] * 2
+            params = (1 / (2 * np.pi), 0.42, 1 / 5, -0.91)
+        H = qml.pulse.rydberg_drive(
+            amps[0], qml.pulse.constant, 0.0, wires=[0]
+        ) + qml.pulse.rydberg_drive(amps[1], qml.pulse.constant, 0.0, wires=[1])
+        atol = 1e-5
+        dev = qml.device("default.qubit.jax", wires=2)
+
+        def ansatz(params):
+            qml.evolve(H, atol=atol)(params, t=timespan)
+            return qml.expval(qml.PauliX(0) @ qml.PauliX(1))
+
+        cost = qml.QNode(
+            ansatz,
+            dev,
+            interface="jax",
+            diff_method=qml.gradients.stoch_pulse_grad,
+            num_split_times=7,
+            use_broadcasting=True,
+            sampler_seed=4123,
+        )
+        cost_jax = qml.QNode(ansatz, dev, interface="jax")
+
+        gradfn = jax.grad(cost)
+        res = gradfn(params)
+        exact = jax.grad(cost_jax)(params)
+        assert qml.math.allclose(res, exact, atol=1e-3)
+        jax.clear_caches()
 
 
 @pytest.mark.jax

From 7f89102db580e954da1257f8c5e55be5e4798ade Mon Sep 17 00:00:00 2001
From: Christina Lee
Date: Mon, 19 Jun 2023 19:42:28 -0400
Subject: [PATCH 2/5] Fix batching of derivative tapes in autograd (#4245)

---
 doc/releases/changelog-0.31.0.md              |  2 +
 pennylane/interfaces/autograd.py              | 24 ++++++-----
 .../test_autograd_default_qubit_2.py          | 22 ++++++++++
 tests/interfaces/test_autograd_new.py         | 43 +++++++++++++++++++
 4 files changed, 81 insertions(+), 10 deletions(-)

diff --git a/doc/releases/changelog-0.31.0.md b/doc/releases/changelog-0.31.0.md
index 21d1511dea6..f00da904469 100644
--- a/doc/releases/changelog-0.31.0.md
+++ b/doc/releases/changelog-0.31.0.md
@@ -64,6 +64,8 @@
 
 <h3>Improvements 🛠</h3>
 
+* The autograd interface now submits all required tapes in a single batch on the backward pass.
+  [(#4245)](https://github.com/PennyLaneAI/pennylane/pull/4245)
+
 * The experimental device interface is integrated with the `QNode`.
   [(#4196)](https://github.com/PennyLaneAI/pennylane/pull/4196)
 
diff --git a/pennylane/interfaces/autograd.py b/pennylane/interfaces/autograd.py
index 8a95421da06..1240e7c94db 100644
--- a/pennylane/interfaces/autograd.py
+++ b/pennylane/interfaces/autograd.py
@@ -411,18 +411,22 @@ def _get_jac_with_caching():
             return cached_jac["jacobian"]
 
         jacs = []
-        for t in tapes:
-            if isinstance(device, qml.devices.experimental.Device):  # pragma: no-cover
-                # cant test until we integrate device with shot vector
-                shot_vector = t.shots.shot_vector if t.shots.has_partitioned_shots else None
-            else:
-                shot_vector = device.shot_vector
-            g_tapes, fn = gradient_fn(t, shots=shot_vector, **gradient_kwargs)
+        if isinstance(device, qml.devices.experimental.Device):
+            shot_vector = (
+                tapes[0].shots.shot_vector if tapes[0].shots.has_partitioned_shots else None
+            )
+        else:
+            shot_vector = device.shot_vector
 
-            unwrapped_tapes = tuple(convert_to_numpy_parameters(g_t) for g_t in g_tapes)
-            res, _ = execute_fn(unwrapped_tapes, **gradient_kwargs)
-            jacs.append(fn(res))
+        def partial_gradient_fn(tape):
+            return gradient_fn(tape, shots=shot_vector, **gradient_kwargs)
+
+        g_tapes, fn = qml.transforms.map_batch_transform(partial_gradient_fn, tapes)
+        unwrapped_tapes = tuple(convert_to_numpy_parameters(g_t) for g_t in g_tapes)
+
+        res, _ = execute_fn(unwrapped_tapes, **gradient_kwargs)
+        jacs = fn(res)
 
         cached_jac["jacobian"] = jacs
         return jacs
diff --git a/tests/interfaces/default_qubit_2_integration/test_autograd_default_qubit_2.py b/tests/interfaces/default_qubit_2_integration/test_autograd_default_qubit_2.py
index cab0de4457c..c4a4198ce52 100644
--- a/tests/interfaces/default_qubit_2_integration/test_autograd_default_qubit_2.py
+++ b/tests/interfaces/default_qubit_2_integration/test_autograd_default_qubit_2.py
@@ -97,6 +97,28 @@ def cost(x, cache):
         assert tracker2.totals["executions"] == expected_runs_ideal
         assert expected_runs_ideal < expected_runs
 
+    def test_single_backward_pass_batch(self):
+        """Tests that the backward pass is one single batch, not a bunch of batches, when parameter shift
+        is requested for multiple tapes."""
+
+        dev = DefaultQubit2()
+
+        def f(x):
+            tape1 = qml.tape.QuantumScript([qml.RX(x, 0)], [qml.probs(wires=0)])
+            tape2 = qml.tape.QuantumScript([qml.RY(x, 0)], [qml.probs(wires=0)])
+
+            results = qml.execute([tape1, tape2], dev, gradient_fn=qml.gradients.param_shift)
+            return results[0] + results[1]
+
+        x = qml.numpy.array(0.1)
+        with dev.tracker:
+            out = qml.jacobian(f)(x)
+
+        assert dev.tracker.totals["batches"] == 2
+        assert dev.tracker.history["executions"] == [2, 4]
+        expected = [-2 * np.cos(x / 2) * np.sin(x / 2), 2 * np.sin(x / 2) * np.cos(x / 2)]
+        assert qml.math.allclose(out, expected)
+
 
 # add tests for lightning 2 when possible
 # set rng for device when possible
diff --git a/tests/interfaces/test_autograd_new.py b/tests/interfaces/test_autograd_new.py
index 68304509829..7dd1bca163c 100644
--- a/tests/interfaces/test_autograd_new.py
+++ b/tests/interfaces/test_autograd_new.py
@@ -415,6 +415,49 @@ def cost(a, cache):
         grad1 = jac_fn(params, cache=True)
         assert dev.num_executions == 2
 
+    def test_single_backward_pass_batch(self):
+        """Tests that the backward pass is one single batch, not a bunch of batches, when parameter shift
+        is requested for multiple tapes."""
+
+        dev = qml.device("default.qubit", wires=2)
+
+        def f(x):
+            tape1 = qml.tape.QuantumScript([qml.RX(x, 0)], [qml.probs(wires=0)])
+            tape2 = qml.tape.QuantumScript([qml.RY(x, 0)], [qml.probs(wires=0)])
+
+            results = qml.execute([tape1, tape2], dev, gradient_fn=qml.gradients.param_shift)
+            return results[0] + results[1]
+
+        x = qml.numpy.array(0.1)
+        with dev.tracker:
+            out = qml.jacobian(f)(x)
+
+        assert dev.tracker.totals["batches"] == 2
+        assert dev.tracker.history["batch_len"] == [2, 4]
+        expected = [-2 * np.cos(x / 2) * np.sin(x / 2), 2 * np.sin(x / 2) * np.cos(x / 2)]
+        assert qml.math.allclose(out, expected)
+
+    def test_single_backward_pass_split_hamiltonian(self):
+        """Tests that the backward pass is one single batch, not a bunch of batches, when parameter shift
+        derivatives are requested for a tape that the device split into batches."""
+
+        dev = qml.device("default.qubit", wires=2)
+
+        H = qml.Hamiltonian([1, 1], [qml.PauliY(0), qml.PauliZ(0)], grouping_type="qwc")
+
+        def f(x):
+            tape = qml.tape.QuantumScript([qml.RX(x, wires=0)], [qml.expval(H)])
+            return qml.execute([tape], dev, gradient_fn=qml.gradients.param_shift)[0]
+
+        x = qml.numpy.array(0.1)
+        with dev.tracker:
+            out = qml.grad(f)(x)
+
+        assert dev.tracker.totals["batches"] == 2
+        assert dev.tracker.history["batch_len"] == [2, 4]
+
+        assert qml.math.allclose(out, -np.cos(x) - np.sin(x))
+
 
 execute_kwargs = [
     {"gradient_fn": param_shift},

From ecd93931d86bce75254b0c59893729b9ed209f42 Mon Sep 17 00:00:00 2001
From: David Wierichs
Date: Tue, 20 Jun 2023 15:59:50 +0200
Subject: [PATCH 3/5] Fix `expval` of `Sum` with broadcasting (#4275)

* fix bug and add test
* changelog addition
---
 doc/releases/changelog-0.31.0.md    |  1 +
 pennylane/devices/default_qubit.py  |  9 ++++--
 tests/devices/test_default_qubit.py | 45 +++++++++++++++++------------
 3 files changed, 34 insertions(+), 21 deletions(-)

diff --git a/doc/releases/changelog-0.31.0.md b/doc/releases/changelog-0.31.0.md
index f00da904469..4213ac92337 100644
--- a/doc/releases/changelog-0.31.0.md
+++ b/doc/releases/changelog-0.31.0.md
@@ -447,6 +447,7 @@
 * Allow for `Sum` observables with trainable parameters.
   [(#4251)](https://github.com/PennyLaneAI/pennylane/pull/4251)
+  [(#4275)](https://github.com/PennyLaneAI/pennylane/pull/4275)
 
 <h3>Contributors ✍️</h3>
 
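Before the device diff: a sketch of the pattern this fix unblocks — a broadcasted expectation value of a `Sum` observable on `default.qubit`. The parameter values mirror the updated tests below; the circuit itself is illustrative.

```python
import numpy as np
import pennylane as qml

dev = qml.device("default.qubit", wires=1)

@qml.qnode(dev)
def circuit(x, y, z):
    qml.RX(x, 0)  # a batched x produces a broadcasted (batched) state
    return qml.expval(qml.sum(qml.s_prod(y, qml.PauliY(0)), qml.s_prod(z, qml.PauliZ(0))))

# One expectation value per batch entry; the Sum interception now passes
# the batching information through to the measurement.
out = circuit(np.array([1.3, 0.4]), 1.1, 2.2)
```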
diff --git a/pennylane/devices/default_qubit.py b/pennylane/devices/default_qubit.py index 35a0a5da1c8..ed33195fbae 100644 --- a/pennylane/devices/default_qubit.py +++ b/pennylane/devices/default_qubit.py @@ -570,10 +570,13 @@ def expval(self, observable, shot_range=None, bin_size=None): Hamiltonian is not NumPy or Autograd """ + is_state_batched = self._ndim(self.state) == 2 # intercept Sums if isinstance(observable, Sum) and not self.shots: return measure( - ExpectationMP(observable.map_wires(self.wire_map)), self._pre_rotated_state + ExpectationMP(observable.map_wires(self.wire_map)), + self._pre_rotated_state, + is_state_batched, ) # intercept other Hamiltonians @@ -592,7 +595,7 @@ def expval(self, observable, shot_range=None, bin_size=None): if backprop_mode: # TODO[dwierichs]: This branch is not adapted to broadcasting yet - if self._ndim(self.state) == 2: + if is_state_batched: raise NotImplementedError( "Expectation values of Hamiltonians for interface!=None are " "not supported together with parameter broadcasting yet" @@ -632,7 +635,7 @@ def expval(self, observable, shot_range=None, bin_size=None): Hmat = observable.sparse_matrix(wire_order=self.wires) state = qml.math.toarray(self.state) - if self._ndim(state) == 2: + if is_state_batched: res = qml.math.array( [ csr_matrix.dot( diff --git a/tests/devices/test_default_qubit.py b/tests/devices/test_default_qubit.py index 85a2ef86781..52788c14ba6 100644 --- a/tests/devices/test_default_qubit.py +++ b/tests/devices/test_default_qubit.py @@ -19,6 +19,7 @@ # pylint: disable=protected-access,cell-var-from-loop import math +from functools import partial import pytest import pennylane as qml @@ -2361,14 +2362,20 @@ def test_Hamiltonian_filtered_from_rotations(self, mocker): assert qml.equal(call_args.measurements[0], qml.expval(qml.PauliX(0))) +@pytest.mark.parametrize("is_state_batched", [False, True]) class TestSumSupport: """Tests for custom Sum support in DefaultQubit.""" - expected_grad = [-np.sin(1.3), np.cos(1.3)] + @staticmethod + def expected_grad(is_state_batched): + if is_state_batched: + return [[-np.sin(1.3), -np.sin(0.4)], [np.cos(1.3), np.cos(0.4)]] + return [-np.sin(1.3), np.cos(1.3)] @staticmethod - def circuit(y, z): - qml.RX(1.3, 0) + def circuit(y, z, is_state_batched): + rx_param = [1.3, 0.4] if is_state_batched else 1.3 + qml.RX(rx_param, 0) return qml.expval( qml.sum( qml.s_prod(y, qml.PauliY(0)), @@ -2376,7 +2383,7 @@ def circuit(y, z): ) ) - def test_super_expval_not_called(self, mocker): + def test_super_expval_not_called(self, is_state_batched, mocker): """Tests basic expval result, and ensures QubitDevice.expval is not called.""" dev = qml.device("default.qubit", wires=1) spy = mocker.spy(qml.QubitDevice, "expval") @@ -2385,28 +2392,30 @@ def test_super_expval_not_called(self, mocker): spy.assert_not_called() @pytest.mark.autograd - def test_trainable_autograd(self): + def test_trainable_autograd(self, is_state_batched): """Tests that coeffs passed to a sum are trainable with autograd.""" + if is_state_batched: + pytest.skip(msg="Broadcasting, qml.jacobian and new return types do not work together") dev = qml.device("default.qubit", wires=1) qnode = qml.QNode(self.circuit, dev, interface="autograd") y, z = np.array([1.1, 2.2]) - actual = qml.grad(qnode)(y, z) - assert np.allclose(actual, self.expected_grad) + actual = qml.grad(qnode, argnum=[0, 1])(y, z, is_state_batched) + assert np.allclose(actual, self.expected_grad(is_state_batched)) @pytest.mark.torch - def test_trainable_torch(self): + def 
From 0df4cb4cc447c54bff2c97742cd669224f9cc1f9 Mon Sep 17 00:00:00 2001
From: Edward Jiang <34989448+eddddddy@users.noreply.github.com>
Date: Tue, 20 Jun 2023 10:53:55 -0400
Subject: [PATCH 4/5] Various doc fixes (#4268)

---
 pennylane/pauli/utils.py                    | 28 ++++++++++++-------
 pennylane/transforms/core/transform.py      |  6 ++--
 .../decompositions/single_qubit_unitary.py  |  6 ++--
 3 files changed, 24 insertions(+), 16 deletions(-)

diff --git a/pennylane/pauli/utils.py b/pennylane/pauli/utils.py
index 5ba806d7912..b020737df3e 100644
--- a/pennylane/pauli/utils.py
+++ b/pennylane/pauli/utils.py
@@ -51,8 +51,7 @@ def _wire_map_from_pauli_pair(pauli_word_1, pauli_word_2):
     return {label: i for i, label in enumerate(wire_labels)}
 
 
-@singledispatch
-def is_pauli_word(observable):  # pylint:disable=unused-argument
+def is_pauli_word(observable):
     """
     Checks if an observable instance consists only of Pauli and Identity Operators.
 
@@ -93,36 +92,45 @@ def is_pauli_word(observable):  # pylint:disable=unused-argument
     >>> is_pauli_word(4 * qml.PauliX(0) @ qml.PauliZ(0))
     True
     """
+    return _is_pauli_word(observable)
+
+
+@singledispatch
+def _is_pauli_word(observable):  # pylint:disable=unused-argument
+    """
+    Private implementation of is_pauli_word, to prevent all of the
+    registered functions from appearing in the Sphinx docs.
+    """
     return False
 
 
-@is_pauli_word.register(PauliX)
-@is_pauli_word.register(PauliY)
-@is_pauli_word.register(PauliZ)
-@is_pauli_word.register(Identity)
+@_is_pauli_word.register(PauliX)
+@_is_pauli_word.register(PauliY)
+@_is_pauli_word.register(PauliZ)
+@_is_pauli_word.register(Identity)
 def _is_pw_pauli(
     observable: Union[PauliX, PauliY, PauliZ, Identity]
 ):  # pylint:disable=unused-argument
     return True
 
 
-@is_pauli_word.register
+@_is_pauli_word.register
 def _is_pw_tensor(observable: Tensor):
     pauli_word_names = ["Identity", "PauliX", "PauliY", "PauliZ"]
     return set(observable.name).issubset(pauli_word_names)
 
 
-@is_pauli_word.register
+@_is_pauli_word.register
 def _is_pw_ham(observable: Hamiltonian):
     return False if len(observable.ops) != 1 else is_pauli_word(observable.ops[0])
 
 
-@is_pauli_word.register
+@_is_pauli_word.register
 def _is_pw_prod(observable: Prod):
     return all(is_pauli_word(op) for op in observable)
 
 
-@is_pauli_word.register
+@_is_pauli_word.register
 def _is_pw_sprod(observable: SProd):
     return is_pauli_word(observable.base)
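The refactor above follows a pattern worth spelling out: keep the documented public function a thin wrapper and register the type-specific implementations on a private `functools.singledispatch` helper, so that only the public entry point appears in the Sphinx docs. A self-contained sketch of the pattern, with purely illustrative names:

```python
# Minimal sketch of the public-wrapper/private-singledispatch pattern used
# above. The names `describe`/`_describe` are illustrative, not from PennyLane.
from functools import singledispatch

def describe(obj):
    """Public, documented entry point."""
    return _describe(obj)

@singledispatch
def _describe(obj):  # pylint:disable=unused-argument
    # Default case for unregistered types
    return "unknown"

@_describe.register
def _(obj: int):
    return "an int"

@_describe.register
def _(obj: str):
    return "a string"

print(describe(3), describe("x"), describe(1.0))  # an int | a string | unknown
```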
+ """ return False -@is_pauli_word.register(PauliX) -@is_pauli_word.register(PauliY) -@is_pauli_word.register(PauliZ) -@is_pauli_word.register(Identity) +@_is_pauli_word.register(PauliX) +@_is_pauli_word.register(PauliY) +@_is_pauli_word.register(PauliZ) +@_is_pauli_word.register(Identity) def _is_pw_pauli( observable: Union[PauliX, PauliY, PauliZ, Identity] ): # pylint:disable=unused-argument return True -@is_pauli_word.register +@_is_pauli_word.register def _is_pw_tensor(observable: Tensor): pauli_word_names = ["Identity", "PauliX", "PauliY", "PauliZ"] return set(observable.name).issubset(pauli_word_names) -@is_pauli_word.register +@_is_pauli_word.register def _is_pw_ham(observable: Hamiltonian): return False if len(observable.ops) != 1 else is_pauli_word(observable.ops[0]) -@is_pauli_word.register +@_is_pauli_word.register def _is_pw_prod(observable: Prod): return all(is_pauli_word(op) for op in observable) -@is_pauli_word.register +@_is_pauli_word.register def _is_pw_sprod(observable: SProd): return is_pauli_word(observable.base) diff --git a/pennylane/transforms/core/transform.py b/pennylane/transforms/core/transform.py index 85bba6262f7..bc2736da9a2 100644 --- a/pennylane/transforms/core/transform.py +++ b/pennylane/transforms/core/transform.py @@ -58,9 +58,9 @@ def post_processing_fn(results): return [tape1, tape2], post_processing_fn - Of course, we want to be able to apply this transform on `qfunc` and `qnodes`. That's where the `transform` function + Of course, we want to be able to apply this transform on ``qfunc`` and ``qnodes``. That's where the ``transform`` function comes into play. This function validates the signature of your quantum transform and dispatches it on the different - object. Let's define a circuit as a qfunc and as qnode. + object. Let's define a circuit as a qfunc and as a qnode. .. code-block:: python @@ -85,7 +85,7 @@ def qnode_circuit(a): Now you can use the dispatched transform directly on qfunc and qnodes. - For QNodes, the dispatched transform populates the `TransformProgram` of your QNode. The transform and its + For QNodes, the dispatched transform populates the ``TransformProgram`` of your QNode. The transform and its processing function are applied in the execution. >>> transformed_qnode = dispatched_transform(qfunc_circuit) diff --git a/pennylane/transforms/decompositions/single_qubit_unitary.py b/pennylane/transforms/decompositions/single_qubit_unitary.py index 6ffba710d5e..fee334cc910 100644 --- a/pennylane/transforms/decompositions/single_qubit_unitary.py +++ b/pennylane/transforms/decompositions/single_qubit_unitary.py @@ -417,10 +417,10 @@ def _zxz_decomposition(U, wire, return_global_phase=False): def one_qubit_decomposition(U, wire, rotations="ZYZ", return_global_phase=False): r"""Decompose a one-qubit unitary :math:`U` in terms of elementary operations. (batched operation) - Any one qubit unitary operation can be implemented upto a global phase by composing RX, RY, + Any one qubit unitary operation can be implemented up to a global phase by composing RX, RY, and RZ gates. - Currently supported values for `rotations` are "ZYZ", "XYX", and "ZXZ". + Currently supported values for ``rotations`` are "ZYZ", "XYX", and "ZXZ". Args: U (tensor): A :math:`2 \times 2` unitary matrix. @@ -431,7 +431,7 @@ def one_qubit_decomposition(U, wire, rotations="ZYZ", return_global_phase=False) Returns: list[Operation]: Returns a list of gates which when applied in the order of appearance in - the list is equivalent to the unitary :math:`U` up to a global phase. 
From 9ac4ace0581effcf83516fbdf1481b8dc58ca449 Mon Sep 17 00:00:00 2001
From: GitHub Actions Bot <>
Date: Wed, 21 Jun 2023 02:52:24 +0000
Subject: [PATCH 5/5] exclude files from pr

---
 doc/development/release_notes.md | 2 ++
 pennylane/_version.py            | 2 +-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/doc/development/release_notes.md b/doc/development/release_notes.md
index bd931c5ff36..1be9bcfea01 100644
--- a/doc/development/release_notes.md
+++ b/doc/development/release_notes.md
@@ -3,6 +3,8 @@ Release notes
 
 This page contains the release notes for PennyLane.
 
+.. mdinclude:: ../releases/changelog-dev.md
+
 .. mdinclude:: ../releases/changelog-0.31.0.md
 
 .. mdinclude:: ../releases/changelog-0.30.0.md
diff --git a/pennylane/_version.py b/pennylane/_version.py
index afaacd4e0ce..3b47daed981 100644
--- a/pennylane/_version.py
+++ b/pennylane/_version.py
@@ -16,4 +16,4 @@
 Version number (major.minor.patch[-label])
 """
 
-__version__ = "0.31.0"
+__version__ = "0.32.0-dev"
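The version bump above can be sanity-checked on a development install; a trivial check, assuming `qml.version()` as the accessor:

```python
# Quick check of the bumped version string on a dev install
# (qml.version() returns pennylane.__version__).
import pennylane as qml

print(qml.version())  # expected to report "0.32.0-dev" at this commit
```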