diff --git a/doc/releases/changelog-0.32.0.md b/doc/releases/changelog-0.32.0.md
index 72f3f9bc225..fa0b6ad2bd9 100644
--- a/doc/releases/changelog-0.32.0.md
+++ b/doc/releases/changelog-0.32.0.md
@@ -454,6 +454,9 @@
 array([False, False])
   [(#4165)](https://github.com/PennyLaneAI/pennylane/pull/4165)
   [(#4482)](https://github.com/PennyLaneAI/pennylane/pull/4482)
 
+* The backprop gradient of `qml.math.fidelity` is now correct.
+  [(#4380)](https://github.com/PennyLaneAI/pennylane/pull/4380)
+
 Contributors ✍️
 
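For quick reference, the fixed behaviour can be exercised with a minimal standalone sketch (illustrative only, not part of this changeset; it assumes the bundled autograd interface):

```python
import pennylane as qml
from pennylane import numpy as np


def cost(x):
    # rho(x) = cos^2(x/2)|0><0| + sin^2(x/2)|1><1|, so F(rho(x), |0><0|) = cos^2(x/2)
    rho = np.cos(x / 2) ** 2 * np.diag([1.0, 0.0]) + np.sin(x / 2) ** 2 * np.diag([0.0, 1.0])
    sigma = np.diag([1.0, 0.0])
    return qml.math.fidelity(rho, sigma)


# rho(pi/2) = I/2 has a fully degenerate spectrum, the case that used to break backprop
x = np.array(np.pi / 2, requires_grad=True)
print(cost(x))            # 0.5
print(qml.grad(cost)(x))  # -sin(pi/2)/2 = -0.5
```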
diff --git a/pennylane/math/__init__.py b/pennylane/math/__init__.py
index e6ea3aec4f7..93b822f8289 100644
--- a/pennylane/math/__init__.py
+++ b/pennylane/math/__init__.py
@@ -67,8 +67,6 @@ from .quantum import (
     cov_matrix,
     dm_from_state_vector,
-    fidelity,
-    fidelity_statevector,
     marginal_prob,
     mutual_info,
     purity,
@@ -80,6 +78,7 @@
     max_entropy,
     trace_distance,
 )
+from .fidelity import fidelity, fidelity_statevector
 from .utils import (
     allclose,
     allequal,
diff --git a/pennylane/math/fidelity.py b/pennylane/math/fidelity.py
new file mode 100644
index 00000000000..97989a085a2
--- /dev/null
+++ b/pennylane/math/fidelity.py
@@ -0,0 +1,357 @@
+# Copyright 2018-2023 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Contains the implementation of quantum fidelity.
+
+Note: care needs to be taken to make it fully differentiable. An explanation can
+be found in pennylane/math/fidelity_gradient.md.
+"""
+from functools import lru_cache
+
+import autograd
+import autoray as ar
+import pennylane as qml
+
+from .utils import cast
+from .quantum import _check_density_matrix, _check_state_vector
+
+
+def fidelity_statevector(state0, state1, check_state=False, c_dtype="complex128"):
+    r"""Compute the fidelity for two states (given as state vectors) acting on quantum
+    systems with the same size.
+
+    The fidelity for two pure states given by state vectors :math:`\ket{\psi}` and :math:`\ket{\phi}`
+    is defined as
+
+    .. math::
+        F( \ket{\psi} , \ket{\phi}) = \left|\braket{\psi, \phi}\right|^2
+
+    This is faster than calling :func:`pennylane.math.fidelity` on the density matrix
+    representation of pure states.
+
+    .. note::
+        It supports all interfaces (NumPy, Autograd, Torch, TensorFlow and JAX). The second state is coerced
+        to the type and dtype of the first state. The fidelity is returned in the type of the interface of the
+        first state.
+
+    Args:
+        state0 (tensor_like): ``(2**N)`` or ``(batch_dim, 2**N)`` state vector.
+        state1 (tensor_like): ``(2**N)`` or ``(batch_dim, 2**N)`` state vector.
+        check_state (bool): If True, the function will check the validity of both states; that is,
+            the shape and the norm.
+        c_dtype (str): Complex floating point precision type.
+
+    Returns:
+        float: Fidelity between the two quantum states.
+
+    **Example**
+
+    Two state vectors can be used as arguments and the fidelity (overlap) is returned, e.g.:
+
+    >>> state0 = [0.98753537-0.14925137j, 0.00746879-0.04941796j]
+    >>> state1 = [0.99500417+0.j, 0.09983342+0.j]
+    >>> qml.math.fidelity_statevector(state0, state1)
+    0.9905158135644924
+
+    .. seealso:: :func:`pennylane.math.fidelity` and :func:`pennylane.qinfo.transforms.fidelity`
+
+    """
+    # Cast as a c_dtype array
+    state0 = cast(state0, dtype=c_dtype)
+    state1 = cast(state1, dtype=c_dtype)
+
+    if check_state:
+        _check_state_vector(state0)
+        _check_state_vector(state1)
+
+    if qml.math.shape(state0)[-1] != qml.math.shape(state1)[-1]:
+        raise qml.QuantumFunctionError("The two states must have the same number of wires.")
+
+    batched0 = len(qml.math.shape(state0)) > 1
+    batched1 = len(qml.math.shape(state1)) > 1
+
+    # Two pure states, squared overlap
+    indices0 = "ab" if batched0 else "b"
+    indices1 = "ab" if batched1 else "b"
+    target = "a" if batched0 or batched1 else ""
+    overlap = qml.math.einsum(
+        f"{indices0},{indices1}->{target}", state0, qml.math.conj(state1), optimize="greedy"
+    )
+
+    overlap = qml.math.abs(overlap) ** 2
+    return overlap
+
+
+def fidelity(state0, state1, check_state=False, c_dtype="complex128"):
+    r"""Compute the fidelity for two states (given as density matrices) acting on quantum
+    systems with the same size.
+
+    The fidelity for two mixed states given by density matrices :math:`\rho` and :math:`\sigma`
+    is defined as
+
+    .. math::
+        F( \rho , \sigma ) = \text{Tr}( \sqrt{\sqrt{\rho} \sigma \sqrt{\rho}})^2
+
+    .. note::
+        It supports all interfaces (NumPy, Autograd, Torch, TensorFlow and JAX). The second state is coerced
+        to the type and dtype of the first state. The fidelity is returned in the type of the interface of the
+        first state.
+
+    Args:
+        state0 (tensor_like): ``(2**N, 2**N)`` or ``(batch_dim, 2**N, 2**N)`` density matrix.
+        state1 (tensor_like): ``(2**N, 2**N)`` or ``(batch_dim, 2**N, 2**N)`` density matrix.
+        check_state (bool): If True, the function will check the validity of both states; that is,
+            (shape, trace, positive semi-definiteness) for density matrices.
+        c_dtype (str): Complex floating point precision type.
+
+    Returns:
+        float: Fidelity between the two quantum states.
+
+    **Example**
+
+    To find the fidelity between two state vectors, call :func:`~.math.dm_from_state_vector` on the
+    inputs first, e.g.:
+
+    >>> state0 = qml.math.dm_from_state_vector([0.98753537-0.14925137j, 0.00746879-0.04941796j])
+    >>> state1 = qml.math.dm_from_state_vector([0.99500417+0.j, 0.09983342+0.j])
+    >>> qml.math.fidelity(state0, state1)
+    0.9905158135644924
+
+    To find the fidelity between two density matrices, they can be passed directly:
+
+    >>> state0 = [[1, 0], [0, 0]]
+    >>> state1 = [[0, 0], [0, 1]]
+    >>> qml.math.fidelity(state0, state1)
+    0.0
+
+    .. seealso:: :func:`pennylane.math.fidelity_statevector` and :func:`pennylane.qinfo.transforms.fidelity`
+
+    """
+    # Cast as a c_dtype array
+    state0 = cast(state0, dtype=c_dtype)
+    state1 = cast(state1, dtype=c_dtype)
+
+    if check_state:
+        _check_density_matrix(state0)
+        _check_density_matrix(state1)
+
+    if qml.math.shape(state0)[-1] != qml.math.shape(state1)[-1]:
+        raise qml.QuantumFunctionError("The two states must have the same number of wires.")
+
+    # Two mixed states
+    _register_vjp(state0, state1)
+    fid = qml.math.compute_fidelity(state0, state1)
+    return fid
+
+
+def _register_vjp(state0, state1):
+    """
+    Register the interface-specific custom VJP based on the interfaces of the given states.
+
+    This function is needed because we don't want to register the custom
+    VJPs at PennyLane import time.
+    """
+    interface = qml.math.get_interface(state0, state1)
+    if interface == "jax":
+        _register_jax_vjp()
+    elif interface == "torch":
+        _register_torch_vjp()
+    elif interface == "tensorflow":
+        _register_tf_vjp()
+
+
+def _compute_fidelity_vanilla(density_matrix0, density_matrix1):
+    r"""Compute the fidelity for two density matrices with the same number of wires.
+
+    .. math::
+        F( \rho , \sigma ) = \text{Tr}( \sqrt{\sqrt{\rho} \sigma \sqrt{\rho}})^2
+    """
+    # Implementation in single dispatches (sqrt(rho))
+    sqrt_mat = qml.math.sqrt_matrix(density_matrix0)
+
+    # sqrt(rho) * sigma * sqrt(rho)
+    sqrt_mat_sqrt = sqrt_mat @ density_matrix1 @ sqrt_mat
+
+    # extract eigenvalues
+    evs = qml.math.eigvalsh(sqrt_mat_sqrt)
+    evs = qml.math.real(evs)
+    evs = qml.math.where(evs > 0.0, evs, 0)
+
+    trace = (qml.math.sum(qml.math.sqrt(evs), -1)) ** 2
+
+    return trace
+
+
+def _compute_fidelity_vjp0(dm0, dm1, grad_out):
+    """
+    Compute the VJP of fidelity with respect to the first density matrix.
+    """
+    # sqrt of sigma
+    sqrt_dm1 = qml.math.sqrt_matrix(dm1)
+
+    # eigendecomposition of sqrt(sigma) * rho * sqrt(sigma)
+    evs0, u0 = qml.math.linalg.eigh(sqrt_dm1 @ dm0 @ sqrt_dm1)
+    evs0 = qml.math.real(evs0)
+    evs0 = qml.math.where(evs0 > 1e-15, evs0, 1e-15)
+    evs0 = qml.math.cast_like(evs0, sqrt_dm1)
+
+    if len(qml.math.shape(dm0)) == 2 and len(qml.math.shape(dm1)) == 2:
+        u0_dag = qml.math.transpose(qml.math.conj(u0))
+        grad_dm0 = sqrt_dm1 @ u0 @ (1 / qml.math.sqrt(evs0)[..., None] * u0_dag) @ sqrt_dm1
+
+        # torch and tensorflow use the Wirtinger derivative, which is a different convention
+        # than the one autograd and jax use for complex differentiation
+        if qml.math.get_interface(dm0) in ["torch", "tensorflow"]:
+            grad_dm0 = qml.math.sum(qml.math.sqrt(evs0), -1) * grad_dm0
+        else:
+            grad_dm0 = qml.math.sum(qml.math.sqrt(evs0), -1) * qml.math.transpose(grad_dm0)
+
+        res = grad_dm0 * qml.math.cast_like(grad_out, grad_dm0)
+        return res
+
+    # broadcasting case
+    u0_dag = qml.math.transpose(qml.math.conj(u0), (0, 2, 1))
+    grad_dm0 = sqrt_dm1 @ u0 @ (1 / qml.math.sqrt(evs0)[..., None] * u0_dag) @ sqrt_dm1
+
+    # torch and tensorflow use the Wirtinger derivative, which is a different convention
+    # than the one autograd and jax use for complex differentiation
+    if qml.math.get_interface(dm0) in ["torch", "tensorflow"]:
+        grad_dm0 = qml.math.sum(qml.math.sqrt(evs0), -1)[:, None, None] * grad_dm0
+    else:
+        grad_dm0 = qml.math.sum(qml.math.sqrt(evs0), -1)[:, None, None] * qml.math.transpose(
+            grad_dm0, (0, 2, 1)
+        )
+
+    return grad_dm0 * qml.math.cast_like(grad_out, grad_dm0)[:, None, None]
+
+
+def _compute_fidelity_vjp1(dm0, dm1, grad_out):
+    """
+    Compute the VJP of fidelity with respect to the second density matrix.
+    """
+    # pylint: disable=arguments-out-of-order
+    return _compute_fidelity_vjp0(dm1, dm0, grad_out)
+
+
+def _compute_fidelity_grad(dm0, dm1, grad_out):
+    return _compute_fidelity_vjp0(dm0, dm1, grad_out), _compute_fidelity_vjp1(dm0, dm1, grad_out)
+
+
+################################ numpy ###################################
+
+ar.register_function("numpy", "compute_fidelity", _compute_fidelity_vanilla)
+
+################################ autograd ################################
+
+
+@autograd.extend.primitive
+def _compute_fidelity_autograd(dm0, dm1):
+    return _compute_fidelity_vanilla(dm0, dm1)
+
+
+def _compute_fidelity_autograd_vjp0(_, dm0, dm1):
+    def vjp(grad_out):
+        return _compute_fidelity_vjp0(dm0, dm1, grad_out)
+
+    return vjp
+
+
+def _compute_fidelity_autograd_vjp1(_, dm0, dm1):
+    def vjp(grad_out):
+        return _compute_fidelity_vjp1(dm0, dm1, grad_out)
+
+    return vjp
+
+
+autograd.extend.defvjp(
+    _compute_fidelity_autograd, _compute_fidelity_autograd_vjp0, _compute_fidelity_autograd_vjp1
+)
+ar.register_function("autograd", "compute_fidelity", _compute_fidelity_autograd)
+
+################################# jax #####################################
+
+
+@lru_cache(maxsize=None)
+def _register_jax_vjp():
+    """
+    Register the custom VJP for JAX.
+    """
+    # pylint: disable=import-outside-toplevel
+    import jax
+
+    @jax.custom_vjp
+    def _compute_fidelity_jax(dm0, dm1):
+        return _compute_fidelity_vanilla(dm0, dm1)
+
+    def _compute_fidelity_jax_fwd(dm0, dm1):
+        fid = _compute_fidelity_jax(dm0, dm1)
+        return fid, (dm0, dm1)
+
+    def _compute_fidelity_jax_bwd(res, grad_out):
+        dm0, dm1 = res
+        return _compute_fidelity_grad(dm0, dm1, grad_out)
+
+    _compute_fidelity_jax.defvjp(_compute_fidelity_jax_fwd, _compute_fidelity_jax_bwd)
+    ar.register_function("jax", "compute_fidelity", _compute_fidelity_jax)
+
+
+################################ torch ###################################
+
+
+@lru_cache(maxsize=None)
+def _register_torch_vjp():
+    """
+    Register the custom VJP for torch.
+    """
+    # pylint: disable=import-outside-toplevel
+    import torch
+
+    class _TorchFidelity(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, dm0, dm1):
+            """Forward pass for _compute_fidelity"""
+            fid = _compute_fidelity_vanilla(dm0, dm1)
+            ctx.save_for_backward(dm0, dm1)
+            return fid
+
+        @staticmethod
+        def backward(ctx, grad_out):
+            """Backward pass for _compute_fidelity"""
+            dm0, dm1 = ctx.saved_tensors
+            return _compute_fidelity_grad(dm0, dm1, grad_out)
+
+    ar.register_function("torch", "compute_fidelity", _TorchFidelity.apply)
+
+
+############################### tensorflow ################################
+
+
+@lru_cache(maxsize=None)
+def _register_tf_vjp():
+    """
+    Register the custom VJP for tensorflow.
+    """
+    # pylint: disable=import-outside-toplevel
+    import tensorflow as tf
+
+    @tf.custom_gradient
+    def _compute_fidelity_tf(dm0, dm1):
+        fid = _compute_fidelity_vanilla(dm0, dm1)
+
+        def vjp(grad_out):
+            return _compute_fidelity_grad(dm0, dm1, grad_out)
+
+        return fid, vjp
+
+    ar.register_function("tensorflow", "compute_fidelity", _compute_fidelity_tf)
diff --git a/pennylane/math/fidelity_gradient.md b/pennylane/math/fidelity_gradient.md
new file mode 100644
index 00000000000..603eb45260a
--- /dev/null
+++ b/pennylane/math/fidelity_gradient.md
@@ -0,0 +1,73 @@
+
+# Gradient of fidelity
+
+## Issue with autodiff
+The fidelity between two density operators $\rho$ and $\sigma$ is given by
+
+$$\begin{align}
+F(\rho,\sigma)=\text{Tr}\left(\sqrt{\sqrt{\rho}\\,\sigma\\,\sqrt{\rho}}\right)^2\qquad\qquad\qquad\text{(1)}
+\end{align}$$
+
+Looking at the above, we multiply the square root of the operator $\rho$ by a generally non-commuting operator $\sigma$, so it is clear that computing $F$ requires the full eigendecomposition of $\rho$, including its eigenvectors. For comparison, the trace distance between $\rho$ and $\sigma$ is given by
+
+$$\begin{align*}
+T(\rho,\sigma)=\frac{1}{2}\text{Tr}\left(|\rho-\sigma|\right)
+\end{align*}$$
+
+which does not require the eigenvectors of $\rho$, $\sigma$, or any function of the two, since the trace can be directly computed from the eigenvalues of $\rho-\sigma$.
+
+The backward gradient of the eigendecomposition $\rho=U\Lambda U^\dagger$ is given by
+
+$$\frac{\partial\mathcal{L}}{\partial\rho}=U^*\left(\frac{\partial\mathcal{L}}{\partial \Lambda}+D\circ U^\dagger\frac{\partial\mathcal{L}}{\partial U}\right)U^T$$
+
+where $D$ is the matrix defined as
+
+$$D_{ij}=\begin{cases}\frac{1}{\lambda_j-\lambda_i}\quad&\text{if }i\neq j\\
+0&\text{if }i=j\end{cases}$$
+
+See section 4.17 [here](https://arxiv.org/pdf/1701.00392.pdf) for more details.
+
+It is clear that if $\rho$ has two or more identical eigenvalues, then the gradient of the eigendecomposition is undefined. Indeed, auto-differentiation frameworks like JAX produce NaN results when backpropagating through $F$ if the inputs $\rho$ or $\sigma$ are sparse (which is often the case for states close to pure). On the other hand, the fidelity is a measure of overlap and thus has a well-defined gradient in every scenario.
+
+Our solution to the above is to skip the gradient computation of the eigendecomposition, and instead treat the fidelity function as a black box while defining a custom gradient.
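+The NaN failure mode is easy to reproduce. The snippet below is illustrative only (it is not part
+of the source tree and assumes JAX is installed); it differentiates through an `eigh`-based matrix
+square root, the same pattern used by `qml.math.sqrt_matrix`:
+
+```python
+import jax
+import jax.numpy as jnp
+
+
+def trace_sqrtm(rho):
+    # matrix square root via an explicit eigendecomposition
+    evals, vecs = jnp.linalg.eigh(rho)
+    return jnp.trace(vecs @ jnp.diag(jnp.sqrt(evals)) @ vecs.conj().T).real
+
+
+rho = jnp.eye(2) / 2  # maximally mixed state, degenerate spectrum {1/2, 1/2}
+print(jax.grad(trace_sqrtm)(rho))  # the eigh VJP divides by eigenvalue gaps -> NaNs
+```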
+
+## Gradient derivation
+
+Let $\sqrt{\rho}\\,\sigma\sqrt{\rho}=U\Lambda U^\dagger$ be the eigendecomposition of $\sqrt{\rho}\\,\sigma\sqrt{\rho}$. In einsum notation, we have
+
+$$\Lambda_{k\ell} = U^\dagger_ {kp}\sqrt{\rho}_ {pm}\sigma_{mn}\sqrt{\rho}_ {nq}U_ {q\ell}$$
+
+whence
+
+$$\frac{\partial\Lambda_{k\ell}}{\partial \sigma_ {mn}}=\sqrt{\rho}_ {nq}U_{q\ell}U^\dagger_{kp}\sqrt{\rho}_{pm}$$
+
+Then equation (1) gives
+
+$$F(\rho,\sigma)=\text{Tr}\left(\sqrt{\sqrt{\rho}\\,\sigma\\,\sqrt{\rho}}\right)^2=\left(\sum_{k=1}^d\sqrt{\Lambda_{kk}}\right)^2$$
+
+and the gradient is easily seen to be
+
+$$\begin{align*}
+\frac{\partial F}{\partial\sigma_{mn}}&=2\left(\sum_{k=1}^d\sqrt{\Lambda_{kk}}\right)\sum_{k=1}^d\sqrt{\rho}_ {nq}U_{qk}\frac{1}{2\sqrt{\Lambda_{kk}}}U^\dagger_{kp}\sqrt{\rho}_{pm}\\
+\end{align*}$$
+
+Converting back to matrix notation, this is
+
+$$\begin{align}
+\frac{\partial F}{\partial\sigma}=\sqrt{F(\rho,\sigma)}\left(\sqrt{\rho}\\,\left(\sqrt{\rho}\\,\sigma\sqrt{\rho}\right)^{-\frac{1}{2}}\sqrt{\rho}\right)^T\qquad\qquad\qquad\text{(2)}
+\end{align}$$
+
+Since $F(\rho,\sigma)=F(\sigma,\rho)$, it follows immediately that
+
+$$\begin{align}
+\frac{\partial F}{\partial\rho}=\sqrt{F(\rho,\sigma)}\left(\sqrt{\sigma}\\,\left(\sqrt{\sigma}\\,\rho\sqrt{\sigma}\right)^{-\frac{1}{2}}\sqrt{\sigma}\right)^T\qquad\qquad\qquad\text{(3)}
+\end{align}$$
+
+Equations (2) and (3) are what we use for the custom gradient of the fidelity function.
+
+Note that these equations are what Autograd and JAX use for the gradient of a real-valued function defined on complex inputs. On the other hand, PyTorch and TensorFlow use the Wirtinger derivatives, given by
+
+$$\begin{align}
+\frac{\partial F}{\partial\rho^\*}&=\sqrt{F(\rho,\sigma)}\sqrt{\sigma}\\,\left(\sqrt{\sigma}\\,\rho\sqrt{\sigma}\right)^{-\frac{1}{2}}\sqrt{\sigma}\\
+\frac{\partial F}{\partial\sigma^\*}&=\sqrt{F(\rho,\sigma)}\sqrt{\rho}\\,\left(\sqrt{\rho}\\,\sigma\sqrt{\rho}\right)^{-\frac{1}{2}}\sqrt{\rho}
+\end{align}$$
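+
+## Numerical sanity check
+
+Equation (3) can be checked directly against a finite-difference estimate. The snippet below is a
+standalone illustration (it assumes only NumPy and PennyLane, and is not imported anywhere):
+
+```python
+import numpy as np
+import pennylane as qml
+
+rho = np.array([[0.7, 0.2], [0.2, 0.3]])
+sigma = np.array([[0.6, 0.1], [0.1, 0.4]])
+
+# right-hand side of equation (3)
+sqrt_sigma = qml.math.sqrt_matrix(sigma)
+evals, vecs = np.linalg.eigh(sqrt_sigma @ rho @ sqrt_sigma)
+inv_sqrt = vecs @ np.diag(1 / np.sqrt(evals)) @ vecs.conj().T
+grad_rho = np.sqrt(qml.math.fidelity(rho, sigma)) * (sqrt_sigma @ inv_sqrt @ sqrt_sigma).T
+
+# finite-difference directional derivative along a Hermitian perturbation
+pert = np.array([[1.0, 0.5], [0.5, -1.0]])
+delta = 1e-6
+numeric = (
+    qml.math.fidelity(rho + delta * pert, sigma) - qml.math.fidelity(rho - delta * pert, sigma)
+) / (2 * delta)
+
+print(np.sum(grad_rho * pert), numeric)  # the two values agree up to O(delta^2)
+```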
diff --git a/pennylane/math/quantum.py b/pennylane/math/quantum.py
index f7020c89b6f..a27f660ed48 100644
--- a/pennylane/math/quantum.py
+++ b/pennylane/math/quantum.py
@@ -720,139 +720,6 @@ def _compute_mutual_info(
     return vn_entropy_1 + vn_entropy_2 - vn_entropy_12
 
 
-def fidelity(state0, state1, check_state=False, c_dtype="complex128"):
-    r"""Compute the fidelity for two states (given as density matrices) acting on quantum
-    systems with the same size.
-
-    The fidelity for two mixed states given by density matrices :math:`\rho` and :math:`\sigma`
-    is defined as
-
-    .. math::
-        F( \rho , \sigma ) = \text{Tr}( \sqrt{\sqrt{\rho} \sigma \sqrt{\rho}})^2
-
-    .. note::
-        It supports all interfaces (Numpy, Autograd, Torch, Tensorflow and Jax). The second state is coerced
-        to the type and dtype of the first state. The fidelity is returned in the type of the interface of the
-        first state.
-
-    Args:
-        state0 (tensor_like): ``(2**N, 2**N)`` or ``(batch_dim, 2**N, 2**N)`` density matrix.
-        state1 (tensor_like): ``(2**N, 2**N)`` or ``(batch_dim, 2**N, 2**N)`` density matrix.
-        check_state (bool): If True, the function will check the validity of both states; that is,
-            (shape, trace, positive-definitiveness) for density matrices.
-        c_dtype (str): Complex floating point precision type.
-
-    Returns:
-        float: Fidelity between the two quantum states.
-
-    **Example**
-
-    To find the fidelity between two state vectors, call :func:`~.math.dm_from_state_vector` on the
-    inputs first, e.g.:
-
-    >>> state0 = qml.math.dm_from_state_vector([0.98753537-0.14925137j, 0.00746879-0.04941796j])
-    >>> state1 = qml.math.dm_from_state_vector([0.99500417+0.j, 0.09983342+0.j])
-    >>> qml.math.fidelity(state0, state1)
-    0.9905158135644924
-
-    To find the fidelity between two density matrices, they can be passed directly:
-
-    >>> state0 = [[1, 0], [0, 0]]
-    >>> state1 = [[0, 0], [0, 1]]
-    >>> qml.math.fidelity(state0, state1)
-    0.0
-
-    .. seealso:: :func:`pennylane.math.fidelity_statevector` and :func:`pennylane.qinfo.transforms.fidelity`
-
-    """
-    # Cast as a c_dtype array
-    state0 = cast(state0, dtype=c_dtype)
-
-    # Cannot be cast_like if jit
-    if not is_abstract(state0):
-        state1 = cast_like(state1, state0)
-
-    if check_state:
-        _check_density_matrix(state0)
-        _check_density_matrix(state1)
-
-    if qml.math.shape(state0)[-1] != qml.math.shape(state1)[-1]:
-        raise qml.QuantumFunctionError("The two states must have the same number of wires.")
-
-    # Two mixed states
-    fid = _compute_fidelity(state0, state1)
-    return fid
-
-
-def fidelity_statevector(state0, state1, check_state=False, c_dtype="complex128"):
-    r"""Compute the fidelity for two states (given as state vectors) acting on quantum
-    systems with the same size.
-
-    The fidelity for two pure states given by state vectors :math:`\ket{\psi}` and :math:`\ket{\phi}`
-    is defined as
-
-    .. math::
-        F( \ket{\psi} , \ket{\phi}) = \left|\braket{\psi, \phi}\right|^2
-
-    This is faster than calling :func:`pennylane.math.fidelity` on the density matrix
-    representation of pure states.
-
-    .. note::
-        It supports all interfaces (Numpy, Autograd, Torch, Tensorflow and Jax). The second state is coerced
-        to the type and dtype of the first state. The fidelity is returned in the type of the interface of the
-        first state.
-
-    Args:
-        state0 (tensor_like): ``(2**N)`` or ``(batch_dim, 2**N)`` state vector.
-        state1 (tensor_like): ``(2**N)`` or ``(batch_dim, 2**N)`` state vector.
-        check_state (bool): If True, the function will check the validity of both states; that is,
-            the shape and the norm
-        c_dtype (str): Complex floating point precision type.
-
-    Returns:
-        float: Fidelity between the two quantum states.
-
-    **Example**
-
-    Two state vectors can be used as arguments and the fidelity (overlap) is returned, e.g.:
-
-    >>> state0 = [0.98753537-0.14925137j, 0.00746879-0.04941796j]
-    >>> state1 = [0.99500417+0.j, 0.09983342+0.j]
-    >>> qml.math.fidelity(state0, state1)
-    0.9905158135644924
-
-    .. seealso:: :func:`pennylane.math.fidelity` and :func:`pennylane.qinfo.transforms.fidelity`
-
-    """
-    # Cast as a c_dtype array
-    state0 = cast(state0, dtype=c_dtype)
-
-    # Cannot be cast_like if jit
-    if not is_abstract(state0):
-        state1 = cast_like(state1, state0)
-
-    if check_state:
-        _check_state_vector(state0)
-        _check_state_vector(state1)
-
-    if qml.math.shape(state0)[-1] != qml.math.shape(state1)[-1]:
-        raise qml.QuantumFunctionError("The two states must have the same number of wires.")
-
-    batched0 = len(qml.math.shape(state0)) > 1
-    batched1 = len(qml.math.shape(state1)) > 1
-
-    # Two pure states, squared overlap
-    indices0 = "ab" if batched0 else "b"
-    indices1 = "ab" if batched1 else "b"
-    target = "a" if batched0 or batched1 else ""
-    overlap = qml.math.einsum(
-        f"{indices0},{indices1}->{target}", state0, np.conj(state1), optimize="greedy"
-    )
-
-    overlap = np.abs(overlap) ** 2
-    return overlap
-
-
 def sqrt_matrix(density_matrix):
     r"""Compute the square root matrix of a density matrix where :math:`\rho = \sqrt{\rho} \times \sqrt{\rho}`
@@ -863,7 +730,7 @@
         (tensor_like): Square root of the density matrix.
     """
     evs, vecs = qml.math.linalg.eigh(density_matrix)
-    evs = np.real(evs)
+    evs = qml.math.real(evs)
     evs = qml.math.where(evs > 0.0, evs, 0.0)
     if not is_abstract(evs):
         evs = qml.math.cast_like(evs, vecs)
@@ -871,32 +738,11 @@
     shape = qml.math.shape(density_matrix)
     if len(shape) > 2:
         # broadcasting case
-        sqrt_evs = qml.math.expand_dims(qml.math.sqrt(evs), 1) * qml.math.eye(shape[-1])
+        i = qml.math.cast_like(qml.math.convert_like(qml.math.eye(shape[-1]), evs), evs)
+        sqrt_evs = qml.math.expand_dims(qml.math.sqrt(evs), 1) * i
         return vecs @ sqrt_evs @ qml.math.conj(qml.math.transpose(vecs, (0, 2, 1)))
 
-    return vecs @ qml.math.diag(np.sqrt(evs)) @ np.conj(np.transpose(vecs))
-
-
-def _compute_fidelity(density_matrix0, density_matrix1):
-    r"""Compute the fidelity for two density matrices with the same number of wires.
-
-    .. math::
-        F( \rho , \sigma ) = -\text{Tr}( \sqrt{\sqrt{\rho} \sigma \sqrt{\rho}})^2
-    """
-    # Implementation in single dispatches (sqrt(rho))
-    sqrt_mat = qml.math.sqrt_matrix(density_matrix0)
-
-    # sqrt(rho) * sigma * sqrt(rho)
-    sqrt_mat_sqrt = sqrt_mat @ density_matrix1 @ sqrt_mat
-
-    # extract eigenvalues
-    evs = qml.math.eigvalsh(sqrt_mat_sqrt)
-    evs = np.real(evs)
-    evs = qml.math.where(evs > 0.0, evs, 0.0)
-
-    trace = (qml.math.sum(qml.math.sqrt(evs), -1)) ** 2
-
-    return trace
+    return vecs @ qml.math.diag(qml.math.sqrt(evs)) @ qml.math.conj(qml.math.transpose(vecs))
 
 
 def _compute_relative_entropy(rho, sigma, base=None):
diff --git a/tests/math/test_fidelity_math.py b/tests/math/test_fidelity_math.py
index 1e3db353ac3..da4a38df49a 100644
--- a/tests/math/test_fidelity_math.py
+++ b/tests/math/test_fidelity_math.py
@@ -216,3 +216,223 @@ def test_broadcast_dm_dm_unbatched(self, check_state, func):
         fidelity = qml.math.fidelity(state0, state1, check_state)
 
         assert qml.math.allclose(fidelity, expected)
+
+
+def cost_fn_single(x):
+    first_term = qml.math.convert_like(qml.math.diag([1.0, 0]), x)
+    second_term = qml.math.convert_like(qml.math.diag([0, 1.0]), x)
+
+    x = qml.math.cast_like(x, first_term)
+    if len(qml.math.shape(x)) == 0:
+        state1 = qml.math.cos(x / 2) ** 2 * first_term + qml.math.sin(x / 2) ** 2 * second_term
+    else:
+        # broadcasting
+        x = x[:, None, None]
+        state1 = qml.math.cos(x / 2) ** 2 * first_term + qml.math.sin(x / 2) ** 2 * second_term
+
+    state2 = qml.math.convert_like(qml.math.diag([1, 0]), state1)
+
+    return qml.math.fidelity(state1, state2) + qml.math.fidelity(state2, state1)
+
+
+def cost_fn_multi1(x):
+    first_term = qml.math.convert_like(qml.math.diag([1.0, 0, 0, 0]), x)
+    second_term = qml.math.convert_like(qml.math.diag([0, 0, 0, 1.0]), x)
+
+    x = qml.math.cast_like(x, first_term)
+
+    if len(qml.math.shape(x)) == 0:
+        state1 = qml.math.cos(x / 2) ** 2 * first_term + qml.math.sin(x / 2) ** 2 * second_term
+    else:
+        # broadcasting
+        x = x[:, None, None]
+        state1 = qml.math.cos(x / 2) ** 2 * first_term + qml.math.sin(x / 2) ** 2 * second_term
+
+    state2 = qml.math.convert_like(qml.math.diag([1, 0, 0, 0]), state1)
+
+    return qml.math.fidelity(state1, state2) + qml.math.fidelity(state2, state1)
+
+
+def cost_fn_multi2(x):
+    first_term = qml.math.convert_like(np.ones((4, 4)) / 4, x)
+    second_term = np.zeros((4, 4))
+    second_term[1:3, 1:3] = np.array([[1, -1], [-1, 1]]) / 2
+    second_term = qml.math.convert_like(second_term, x)
+
+    x = qml.math.cast_like(x, first_term)
+
+    if len(qml.math.shape(x)) == 0:
+        state1 = qml.math.cos(x / 2) ** 2 * first_term + qml.math.sin(x / 2) ** 2 * second_term
+    else:
+        # broadcasting
+        x = x[:, None, None]
+        state1 = qml.math.cos(x / 2) ** 2 * first_term + qml.math.sin(x / 2) ** 2 * second_term
+
+    state2 = qml.math.convert_like(qml.math.diag([1, 0, 0, 0]), state1)
+
+    return qml.math.fidelity(state1, state2) + qml.math.fidelity(state2, state1)
+
+
+def expected_res_single(x):
+    return 2 * qml.math.cos(x / 2) ** 2
+
+
+def expected_res_multi1(x):
+    return 2 * qml.math.cos(x / 2) ** 2
+
+
+def expected_res_multi2(x):
+    return qml.math.cos(x / 2) ** 2 / 2
+
+
+def expected_grad_single(x):
+    return -qml.math.sin(x)
+
+
+def expected_grad_multi1(x):
+    return -qml.math.sin(x)
+
+
+def expected_grad_multi2(x):
+    return -qml.math.sin(x) / 4
+
+
+class TestGradient:
+    """Test the gradient of qml.math.fidelity"""
+
+    # pylint: disable=too-many-arguments
+
+    cost_fns = [
+        (cost_fn_single, expected_res_single, expected_grad_single),
+        (cost_fn_multi1, expected_res_multi1, expected_grad_multi1),
+        (cost_fn_multi2, expected_res_multi2, expected_grad_multi2),
+    ]
+
+    @pytest.mark.autograd
+    @pytest.mark.parametrize("x", [0.0, 1e-7, 0.456, np.pi / 2 - 1e-7, np.pi / 2])
+    @pytest.mark.parametrize("cost_fn, expected_res, expected_grad", cost_fns)
+    def test_grad_autograd(self, x, cost_fn, expected_res, expected_grad, tol):
+        """Test gradients are correct for autograd"""
+        x = np.array(x)
+        res = cost_fn(x)
+        grad = qml.grad(cost_fn)(x)
+
+        assert qml.math.allclose(res, expected_res(x), tol)
+        assert qml.math.allclose(grad, expected_grad(x), tol)
+
+    @pytest.mark.jax
+    @pytest.mark.parametrize("x", [0.0, 1e-7, 0.456, np.pi / 2 - 1e-7, np.pi / 2])
+    @pytest.mark.parametrize("cost_fn, expected_res, expected_grad", cost_fns)
+    def test_grad_jax(self, x, cost_fn, expected_res, expected_grad, tol):
+        """Test gradients are correct for jax"""
+        x = jnp.array(x)
+        res = cost_fn(x)
+        grad = jax.grad(cost_fn)(x)
+
+        assert qml.math.allclose(res, expected_res(x), tol)
+        assert qml.math.allclose(grad, expected_grad(x), tol)
+
+    @pytest.mark.jax
+    @pytest.mark.parametrize("x", [0.0, 1e-7, 0.456, np.pi / 2 - 1e-7, np.pi / 2])
+    @pytest.mark.parametrize("cost_fn, expected_res, expected_grad", cost_fns)
+    def test_grad_jax_jit(self, x, cost_fn, expected_res, expected_grad, tol):
+        """Test gradients are correct for jax-jit"""
+        x = jnp.array(x)
+
+        jitted_cost = jax.jit(cost_fn)
+        res = jitted_cost(x)
+        grad = jax.grad(jitted_cost)(x)
+
+        assert qml.math.allclose(res, expected_res(x), tol)
+        assert qml.math.allclose(grad, expected_grad(x), tol)
+
+    @pytest.mark.torch
+    @pytest.mark.parametrize("x", [0.0, 1e-7, 0.456, np.pi / 2 - 1e-7, np.pi / 2])
+    @pytest.mark.parametrize("cost_fn, expected_res, expected_grad", cost_fns)
+    def test_grad_torch(self, x, cost_fn, expected_res, expected_grad, tol):
+        """Test gradients are correct for torch"""
+        x = torch.from_numpy(np.array(x)).requires_grad_()
+        res = cost_fn(x)
+        res.backward()
+        grad = x.grad
+
+        assert qml.math.allclose(res, expected_res(x), tol)
+        assert qml.math.allclose(grad, expected_grad(x), tol)
+
+    @pytest.mark.tf
+    @pytest.mark.parametrize("x", [0.0, 1e-7, 0.456, np.pi / 2 - 1e-7, np.pi / 2])
+    @pytest.mark.parametrize("cost_fn, expected_res, expected_grad", cost_fns)
+    def test_grad_tf(self, x, cost_fn, expected_res, expected_grad, tol):
+        """Test gradients are correct for tf"""
+        x = tf.Variable(x, trainable=True)
+
+        with tf.GradientTape() as tape:
+            res = cost_fn(x)
+
+        grad = tape.gradient(res, x)
+
+        assert qml.math.allclose(res, expected_res(x), tol)
+        assert qml.math.allclose(grad, expected_grad(x), tol)
+
+    @pytest.mark.autograd
+    @pytest.mark.parametrize("cost_fn, expected_res, expected_grad", cost_fns)
+    def test_broadcast_autograd(self, cost_fn, expected_res, expected_grad, tol):
+        """Test gradients are correct for a broadcasted input for autograd"""
+        x = np.array([0.0, 1e-7, 0.456, np.pi / 2 - 1e-7, np.pi / 2])
+        res = cost_fn(x)
+        grad = qml.math.diag(qml.jacobian(cost_fn)(x))
+
+        assert qml.math.allclose(res, expected_res(x), tol)
+        assert qml.math.allclose(grad, expected_grad(x), tol)
+
+    @pytest.mark.jax
+    @pytest.mark.parametrize("cost_fn, expected_res, expected_grad", cost_fns)
+    def test_broadcast_jax(self, cost_fn, expected_res, expected_grad, tol):
+        """Test gradients are correct for a broadcasted input for jax"""
+        x = jnp.array([0.0, 1e-7, 0.456, np.pi / 2 - 1e-7, np.pi / 2])
+        res = cost_fn(x)
+        grad = qml.math.diag(jax.jacobian(cost_fn)(x))
+
+        assert qml.math.allclose(res, expected_res(x), tol)
+        assert qml.math.allclose(grad, expected_grad(x), tol)
+
+    @pytest.mark.jax
+    @pytest.mark.parametrize("cost_fn, expected_res, expected_grad", cost_fns)
+    def test_broadcast_jax_jit(self, cost_fn, expected_res, expected_grad, tol):
+        """Test gradients are correct for a broadcasted input for jax-jit"""
+        x = jnp.array([0.0, 1e-7, 0.456, np.pi / 2 - 1e-7, np.pi / 2])
+
+        jitted_cost = jax.jit(cost_fn)
+        res = jitted_cost(x)
+        grad = qml.math.diag(jax.jacobian(jitted_cost)(x))
+
+        assert qml.math.allclose(res, expected_res(x), tol)
+        assert qml.math.allclose(grad, expected_grad(x), tol)
+
+    @pytest.mark.torch
+    @pytest.mark.parametrize("cost_fn, expected_res, expected_grad", cost_fns)
+    def test_broadcast_torch(self, cost_fn, expected_res, expected_grad, tol):
+        """Test gradients are correct for a broadcasted input for torch"""
+        x = torch.from_numpy(
+            np.array([0.0, 1e-7, 0.456, np.pi / 2 - 1e-7, np.pi / 2])
+        ).requires_grad_()
+
+        res = cost_fn(x)
+        grad = qml.math.diag(torch.autograd.functional.jacobian(cost_fn, x))
+
+        assert qml.math.allclose(res, expected_res(x), tol)
+        assert qml.math.allclose(grad, expected_grad(x), tol)
+
+    @pytest.mark.tf
+    @pytest.mark.parametrize("cost_fn, expected_res, expected_grad", cost_fns)
+    def test_broadcast_tf(self, cost_fn, expected_res, expected_grad, tol):
+        """Test gradients are correct for a broadcasted input for tf"""
+        x = tf.Variable([0.0, 1e-7, 0.456, np.pi / 2 - 1e-7, np.pi / 2], trainable=True)
+
+        with tf.GradientTape() as tape:
+            res = cost_fn(x)
+
+        grad = tape.gradient(res, x)
+
+        assert qml.math.allclose(res, expected_res(x), tol)
+        assert qml.math.allclose(grad, expected_grad(x), tol)
diff --git a/tests/qinfo/test_fidelity.py b/tests/qinfo/test_fidelity.py
index 4d4604b33cb..90c59386d00 100644
--- a/tests/qinfo/test_fidelity.py
+++ b/tests/qinfo/test_fidelity.py
@@ -462,6 +462,30 @@ def circuit1(x):
         fid_grad = tape.gradient(entropy, param)
         assert qml.math.allclose(fid_grad, expected_fid)
 
+    @pytest.mark.tf
+    @pytest.mark.parametrize("param", parameters)
+    @pytest.mark.parametrize("wire", wires)
+    @pytest.mark.parametrize("interface", interfaces)
+    def test_fidelity_qnodes_rx_tworx_tf_grad(self, param, wire, interface):
+        """Test the gradient of the fidelity between two trainable circuits with TensorFlow."""
+        import tensorflow as tf
+
+        dev = qml.device("default.qubit", wires=wire)
+
+        @qml.qnode(dev, interface=interface, diff_method="backprop")
+        def circuit(x):
+            qml.RX(x, wires=0)
+            return qml.state()
+
+        expected = expected_grad_fidelity_rx_pauliz(param)
+        expected_fid = [-expected, expected]
+        params = (tf.Variable(param), tf.Variable(2 * param))
+        with tf.GradientTape() as tape:
+            fid = qml.qinfo.fidelity(circuit, circuit, wires0=[0], wires1=[0])(*params)
+
+        fid_grad = tape.gradient(fid, params)
+        assert qml.math.allclose(fid_grad, expected_fid)
+
     interfaces = ["jax"]
 
     @pytest.mark.jax
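The torch path covered by the new tests can also be reproduced standalone. A minimal sketch (illustrative only; it assumes PyTorch is installed and mirrors `cost_fn_single` above):

```python
import numpy as np
import torch

import pennylane as qml

x = torch.tensor(0.456, dtype=torch.float64, requires_grad=True)

# rho(x) = cos^2(x/2)|0><0| + sin^2(x/2)|1><1| is diagonal, the previously NaN-prone case
proj0 = torch.diag(torch.tensor([1.0, 0.0], dtype=torch.float64))  # |0><0|
proj1 = torch.diag(torch.tensor([0.0, 1.0], dtype=torch.float64))  # |1><1|
rho = torch.cos(x / 2) ** 2 * proj0 + torch.sin(x / 2) ** 2 * proj1

fid = qml.math.fidelity(rho, proj0)  # = cos^2(x/2)
fid.backward()
print(x.grad, -np.sin(0.456) / 2)  # the two values should agree
```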