diff --git a/doc/releases/changelog-0.32.0.md b/doc/releases/changelog-0.32.0.md
index 72f3f9bc225..fa0b6ad2bd9 100644
--- a/doc/releases/changelog-0.32.0.md
+++ b/doc/releases/changelog-0.32.0.md
@@ -454,6 +454,9 @@ array([False, False])
[(#4165)](https://github.com/PennyLaneAI/pennylane/pull/4165)
[(#4482)](https://github.com/PennyLaneAI/pennylane/pull/4482)
+* The backprop gradient of `qml.math.fidelity` is now correct.
+ [(#4380)](https://github.com/PennyLaneAI/pennylane/pull/4380)
+
Contributors ✍️
diff --git a/pennylane/math/__init__.py b/pennylane/math/__init__.py
index e6ea3aec4f7..93b822f8289 100644
--- a/pennylane/math/__init__.py
+++ b/pennylane/math/__init__.py
@@ -67,8 +67,6 @@
from .quantum import (
cov_matrix,
dm_from_state_vector,
- fidelity,
- fidelity_statevector,
marginal_prob,
mutual_info,
purity,
@@ -80,6 +78,7 @@
max_entropy,
trace_distance,
)
+from .fidelity import fidelity, fidelity_statevector
from .utils import (
allclose,
allequal,
diff --git a/pennylane/math/fidelity.py b/pennylane/math/fidelity.py
new file mode 100644
index 00000000000..97989a085a2
--- /dev/null
+++ b/pennylane/math/fidelity.py
@@ -0,0 +1,357 @@
+# Copyright 2018-2023 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Contains the implementation of quantum fidelity.
+
+Note: care needs to be taken to make the fidelity fully differentiable. An explanation can
+be found in pennylane/math/fidelity_gradient.md
+"""
+from functools import lru_cache
+
+import autograd
+import autoray as ar
+import pennylane as qml
+
+from .utils import cast
+from .quantum import _check_density_matrix, _check_state_vector
+
+
+def fidelity_statevector(state0, state1, check_state=False, c_dtype="complex128"):
+ r"""Compute the fidelity for two states (given as state vectors) acting on quantum
+    systems of the same size.
+
+ The fidelity for two pure states given by state vectors :math:`\ket{\psi}` and :math:`\ket{\phi}`
+ is defined as
+
+ .. math::
+ F( \ket{\psi} , \ket{\phi}) = \left|\braket{\psi, \phi}\right|^2
+
+ This is faster than calling :func:`pennylane.math.fidelity` on the density matrix
+ representation of pure states.
+
+ .. note::
+        It supports all interfaces (NumPy, Autograd, Torch, TensorFlow and JAX). The second state is coerced
+ to the type and dtype of the first state. The fidelity is returned in the type of the interface of the
+ first state.
+
+ Args:
+ state0 (tensor_like): ``(2**N)`` or ``(batch_dim, 2**N)`` state vector.
+ state1 (tensor_like): ``(2**N)`` or ``(batch_dim, 2**N)`` state vector.
+ check_state (bool): If True, the function will check the validity of both states; that is,
+            the shape and the norm.
+ c_dtype (str): Complex floating point precision type.
+
+ Returns:
+ float: Fidelity between the two quantum states.
+
+ **Example**
+
+ Two state vectors can be used as arguments and the fidelity (overlap) is returned, e.g.:
+
+ >>> state0 = [0.98753537-0.14925137j, 0.00746879-0.04941796j]
+ >>> state1 = [0.99500417+0.j, 0.09983342+0.j]
+    >>> qml.math.fidelity_statevector(state0, state1)
+ 0.9905158135644924
+
+ .. seealso:: :func:`pennylane.math.fidelity` and :func:`pennylane.qinfo.transforms.fidelity`
+
+ """
+ # Cast as a c_dtype array
+ state0 = cast(state0, dtype=c_dtype)
+ state1 = cast(state1, dtype=c_dtype)
+
+ if check_state:
+ _check_state_vector(state0)
+ _check_state_vector(state1)
+
+ if qml.math.shape(state0)[-1] != qml.math.shape(state1)[-1]:
+ raise qml.QuantumFunctionError("The two states must have the same number of wires.")
+
+ batched0 = len(qml.math.shape(state0)) > 1
+ batched1 = len(qml.math.shape(state1)) > 1
+
+ # Two pure states, squared overlap
+ indices0 = "ab" if batched0 else "b"
+ indices1 = "ab" if batched1 else "b"
+ target = "a" if batched0 or batched1 else ""
+ overlap = qml.math.einsum(
+ f"{indices0},{indices1}->{target}", state0, qml.math.conj(state1), optimize="greedy"
+ )
+
+ overlap = qml.math.abs(overlap) ** 2
+ return overlap
+
+
+def fidelity(state0, state1, check_state=False, c_dtype="complex128"):
+ r"""Compute the fidelity for two states (given as density matrices) acting on quantum
+    systems of the same size.
+
+ The fidelity for two mixed states given by density matrices :math:`\rho` and :math:`\sigma`
+ is defined as
+
+ .. math::
+ F( \rho , \sigma ) = \text{Tr}( \sqrt{\sqrt{\rho} \sigma \sqrt{\rho}})^2
+
+ .. note::
+        It supports all interfaces (NumPy, Autograd, Torch, TensorFlow and JAX). The second state is coerced
+ to the type and dtype of the first state. The fidelity is returned in the type of the interface of the
+ first state.
+
+ Args:
+ state0 (tensor_like): ``(2**N, 2**N)`` or ``(batch_dim, 2**N, 2**N)`` density matrix.
+ state1 (tensor_like): ``(2**N, 2**N)`` or ``(batch_dim, 2**N, 2**N)`` density matrix.
+ check_state (bool): If True, the function will check the validity of both states; that is,
+            the shape, trace, and positive semi-definiteness of the density matrices.
+ c_dtype (str): Complex floating point precision type.
+
+ Returns:
+ float: Fidelity between the two quantum states.
+
+ **Example**
+
+ To find the fidelity between two state vectors, call :func:`~.math.dm_from_state_vector` on the
+ inputs first, e.g.:
+
+ >>> state0 = qml.math.dm_from_state_vector([0.98753537-0.14925137j, 0.00746879-0.04941796j])
+ >>> state1 = qml.math.dm_from_state_vector([0.99500417+0.j, 0.09983342+0.j])
+ >>> qml.math.fidelity(state0, state1)
+ 0.9905158135644924
+
+ To find the fidelity between two density matrices, they can be passed directly:
+
+ >>> state0 = [[1, 0], [0, 0]]
+ >>> state1 = [[0, 0], [0, 1]]
+ >>> qml.math.fidelity(state0, state1)
+ 0.0
+
+ .. seealso:: :func:`pennylane.math.fidelity_statevector` and :func:`pennylane.qinfo.transforms.fidelity`
+
+ """
+ # Cast as a c_dtype array
+ state0 = cast(state0, dtype=c_dtype)
+ state1 = cast(state1, dtype=c_dtype)
+
+ if check_state:
+ _check_density_matrix(state0)
+ _check_density_matrix(state1)
+
+ if qml.math.shape(state0)[-1] != qml.math.shape(state1)[-1]:
+ raise qml.QuantumFunctionError("The two states must have the same number of wires.")
+
+ # Two mixed states
+ _register_vjp(state0, state1)
+ fid = qml.math.compute_fidelity(state0, state1)
+ return fid
+
+
+def _register_vjp(state0, state1):
+ """
+ Register the interface-specific custom VJP based on the interfaces of the given states
+
+ This function is needed because we don't want to register the custom
+ VJPs at PennyLane import time.
+ """
+ interface = qml.math.get_interface(state0, state1)
+ if interface == "jax":
+ _register_jax_vjp()
+ elif interface == "torch":
+ _register_torch_vjp()
+ elif interface == "tensorflow":
+ _register_tf_vjp()
+
+
+def _compute_fidelity_vanilla(density_matrix0, density_matrix1):
+ r"""Compute the fidelity for two density matrices with the same number of wires.
+
+ .. math::
+        F( \rho , \sigma ) = \text{Tr}( \sqrt{\sqrt{\rho} \sigma \sqrt{\rho}})^2
+ """
+    # sqrt(rho) is computed via the interface-dispatched qml.math.sqrt_matrix
+ sqrt_mat = qml.math.sqrt_matrix(density_matrix0)
+
+ # sqrt(rho) * sigma * sqrt(rho)
+ sqrt_mat_sqrt = sqrt_mat @ density_matrix1 @ sqrt_mat
+
+ # extract eigenvalues
+ evs = qml.math.eigvalsh(sqrt_mat_sqrt)
+ evs = qml.math.real(evs)
+    evs = qml.math.where(evs > 0.0, evs, 0.0)
+
+ trace = (qml.math.sum(qml.math.sqrt(evs), -1)) ** 2
+
+ return trace
+
+
+def _compute_fidelity_vjp0(dm0, dm1, grad_out):
+ """
+ Compute the VJP of fidelity with respect to the first density matrix
+ """
+ # sqrt of sigma
+ sqrt_dm1 = qml.math.sqrt_matrix(dm1)
+
+ # eigendecomposition of sqrt(sigma) * rho * sqrt(sigma)
+ evs0, u0 = qml.math.linalg.eigh(sqrt_dm1 @ dm0 @ sqrt_dm1)
+ evs0 = qml.math.real(evs0)
+ evs0 = qml.math.where(evs0 > 1e-15, evs0, 1e-15)
+ evs0 = qml.math.cast_like(evs0, sqrt_dm1)
+
+ if len(qml.math.shape(dm0)) == 2 and len(qml.math.shape(dm1)) == 2:
+ u0_dag = qml.math.transpose(qml.math.conj(u0))
+ grad_dm0 = sqrt_dm1 @ u0 @ (1 / qml.math.sqrt(evs0)[..., None] * u0_dag) @ sqrt_dm1
+
+ # torch and tensorflow use the Wirtinger derivative which is a different convention
+ # than the one autograd and jax use for complex differentiation
+ if qml.math.get_interface(dm0) in ["torch", "tensorflow"]:
+ grad_dm0 = qml.math.sum(qml.math.sqrt(evs0), -1) * grad_dm0
+ else:
+ grad_dm0 = qml.math.sum(qml.math.sqrt(evs0), -1) * qml.math.transpose(grad_dm0)
+
+ res = grad_dm0 * qml.math.cast_like(grad_out, grad_dm0)
+ return res
+
+ # broadcasting case
+ u0_dag = qml.math.transpose(qml.math.conj(u0), (0, 2, 1))
+ grad_dm0 = sqrt_dm1 @ u0 @ (1 / qml.math.sqrt(evs0)[..., None] * u0_dag) @ sqrt_dm1
+
+ # torch and tensorflow use the Wirtinger derivative which is a different convention
+ # than the one autograd and jax use for complex differentiation
+ if qml.math.get_interface(dm0) in ["torch", "tensorflow"]:
+ grad_dm0 = qml.math.sum(qml.math.sqrt(evs0), -1)[:, None, None] * grad_dm0
+ else:
+ grad_dm0 = qml.math.sum(qml.math.sqrt(evs0), -1)[:, None, None] * qml.math.transpose(
+ grad_dm0, (0, 2, 1)
+ )
+
+ return grad_dm0 * qml.math.cast_like(grad_out, grad_dm0)[:, None, None]
+
+
+def _compute_fidelity_vjp1(dm0, dm1, grad_out):
+ """
+ Compute the VJP of fidelity with respect to the second density matrix
+ """
+ # pylint: disable=arguments-out-of-order
+ return _compute_fidelity_vjp0(dm1, dm0, grad_out)
+
+
+def _compute_fidelity_grad(dm0, dm1, grad_out):
+ return _compute_fidelity_vjp0(dm0, dm1, grad_out), _compute_fidelity_vjp1(dm0, dm1, grad_out)
+
+
+################################ numpy ###################################
+
+ar.register_function("numpy", "compute_fidelity", _compute_fidelity_vanilla)
+
+################################ autograd ################################
+
+
+@autograd.extend.primitive
+def _compute_fidelity_autograd(dm0, dm1):
+ return _compute_fidelity_vanilla(dm0, dm1)
+
+
+def _compute_fidelity_autograd_vjp0(_, dm0, dm1):
+ def vjp(grad_out):
+ return _compute_fidelity_vjp0(dm0, dm1, grad_out)
+
+ return vjp
+
+
+def _compute_fidelity_autograd_vjp1(_, dm0, dm1):
+ def vjp(grad_out):
+ return _compute_fidelity_vjp1(dm0, dm1, grad_out)
+
+ return vjp
+
+
+autograd.extend.defvjp(
+ _compute_fidelity_autograd, _compute_fidelity_autograd_vjp0, _compute_fidelity_autograd_vjp1
+)
+ar.register_function("autograd", "compute_fidelity", _compute_fidelity_autograd)
+
+################################# jax #####################################
+
+
+@lru_cache(maxsize=None)
+def _register_jax_vjp():
+ """
+ Register the custom VJP for JAX
+ """
+ # pylint: disable=import-outside-toplevel
+ import jax
+
+ @jax.custom_vjp
+ def _compute_fidelity_jax(dm0, dm1):
+ return _compute_fidelity_vanilla(dm0, dm1)
+
+ def _compute_fidelity_jax_fwd(dm0, dm1):
+ fid = _compute_fidelity_jax(dm0, dm1)
+ return fid, (dm0, dm1)
+
+ def _compute_fidelity_jax_bwd(res, grad_out):
+ dm0, dm1 = res
+ return _compute_fidelity_grad(dm0, dm1, grad_out)
+
+ _compute_fidelity_jax.defvjp(_compute_fidelity_jax_fwd, _compute_fidelity_jax_bwd)
+ ar.register_function("jax", "compute_fidelity", _compute_fidelity_jax)
+
+
+################################ torch ###################################
+
+
+@lru_cache(maxsize=None)
+def _register_torch_vjp():
+ """
+ Register the custom VJP for torch
+ """
+ # pylint: disable=import-outside-toplevel
+ import torch
+
+ class _TorchFidelity(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, dm0, dm1):
+ """Forward pass for _compute_fidelity"""
+ fid = _compute_fidelity_vanilla(dm0, dm1)
+ ctx.save_for_backward(dm0, dm1)
+ return fid
+
+ @staticmethod
+ def backward(ctx, grad_out):
+ """Backward pass for _compute_fidelity"""
+ dm0, dm1 = ctx.saved_tensors
+ return _compute_fidelity_grad(dm0, dm1, grad_out)
+
+ ar.register_function("torch", "compute_fidelity", _TorchFidelity.apply)
+
+
+############################### tensorflow ################################
+
+
+@lru_cache(maxsize=None)
+def _register_tf_vjp():
+ """
+ Register the custom VJP for tensorflow
+ """
+ # pylint: disable=import-outside-toplevel
+ import tensorflow as tf
+
+ @tf.custom_gradient
+ def _compute_fidelity_tf(dm0, dm1):
+ fid = _compute_fidelity_vanilla(dm0, dm1)
+
+ def vjp(grad_out):
+ return _compute_fidelity_grad(dm0, dm1, grad_out)
+
+ return fid, vjp
+
+ ar.register_function("tensorflow", "compute_fidelity", _compute_fidelity_tf)
diff --git a/pennylane/math/fidelity_gradient.md b/pennylane/math/fidelity_gradient.md
new file mode 100644
index 00000000000..603eb45260a
--- /dev/null
+++ b/pennylane/math/fidelity_gradient.md
@@ -0,0 +1,73 @@
+
+# Gradient of fidelity
+
+## Issue with autodiff
+The fidelity between two density operators $\rho$ and $\sigma$ is given by
+
+$$\begin{align}
+F(\rho,\sigma)=\text{Tr}\left(\sqrt{\sqrt{\rho}\\,\sigma\\,\sqrt{\rho}}\right)^2\qquad\qquad\qquad\text{(1)}
+\end{align}$$
+
+In the expression above, the square root of $\rho$ is multiplied by a generally non-commuting operator $\sigma$, so computing $F$ requires the full eigendecomposition of $\rho$, including its eigenvectors. For comparison, the trace distance between $\rho$ and $\sigma$ is given by
+
+$$\begin{align*}
+T(\rho,\sigma)=\frac{1}{2}\text{Tr}\left(|\rho-\sigma|\right)
+\end{align*}$$
+
+which does not require the eigenvectors of $\rho$, $\sigma$, or any function of the two, since the trace can be directly computed from the eigenvalues of $\rho-\sigma$.
+
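+For concreteness, a minimal NumPy sketch of this eigenvalue-only computation (the two states are illustrative):
+
+```python
+import numpy as np
+
+# Illustrative states: the pure state |0><0| and the maximally mixed state
+rho = np.diag([1.0, 0.0])
+sigma = np.eye(2) / 2
+
+# The trace distance needs only the eigenvalues of rho - sigma
+trace_distance = 0.5 * np.sum(np.abs(np.linalg.eigvalsh(rho - sigma)))
+print(trace_distance)  # 0.5
+```
+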
+The backward gradient of the eigendecomposition $\rho=U\Lambda U^\dagger$ is given by
+
+$$\frac{\partial\mathcal{L}}{\partial\rho}=U^*\left(\frac{\partial\mathcal{L}}{\partial \Lambda}+D\circ U^\dagger\frac{\partial\mathcal{L}}{\partial U}\right)U^T$$
+
+where $D$ is the matrix defined as
+
+$$D_{ij}=\begin{cases}\frac{1}{\lambda_j-\lambda_i}\quad&\text{if }i\neq j\\
+0&\text{if }i=j\end{cases}$$
+
+See section 4.17 [here](https://arxiv.org/pdf/1701.00392.pdf) for more details.
+
+It is clear that if $\rho$ has two or more equal eigenvalues, the gradient of the eigendecomposition is undefined. Indeed, auto-differentiation frameworks like JAX produce NaN results when backpropagating through $F$ if the inputs $\rho$ or $\sigma$ have degenerate eigenvalues, which is often the case for sparse density matrices and states close to pure. On the other hand, the fidelity is a measure of overlap and thus has a well-defined gradient in every scenario.
+
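+The failure is easy to reproduce with a naive eigendecomposition-based implementation (a minimal sketch for illustration, assuming `jax` is installed; this is not the implementation used in `fidelity.py`):
+
+```python
+import jax
+import jax.numpy as jnp
+
+def naive_fidelity(rho, sigma):
+    # sqrt(rho) via eigendecomposition -- autodiff differentiates through eigh
+    evs, u = jnp.linalg.eigh(rho)
+    sqrt_rho = (u * jnp.sqrt(jnp.clip(evs, 0.0, None))) @ u.conj().T
+    evs2 = jnp.linalg.eigvalsh(sqrt_rho @ sigma @ sqrt_rho)
+    return jnp.sum(jnp.sqrt(jnp.clip(evs2, 0.0, None))) ** 2
+
+# A pure state: the zero eigenvalue is triply degenerate, so D is undefined
+rho = jnp.diag(jnp.array([1.0, 0.0, 0.0, 0.0]))
+print(naive_fidelity(rho, rho))            # 1.0 -- the forward pass is fine
+print(jax.grad(naive_fidelity)(rho, rho))  # the gradient contains NaNs
+```
+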
+Our solution is to bypass differentiation through the eigendecomposition entirely: we treat the fidelity computation as a black box and register a custom gradient for it.
+
+## Gradient derivation
+
+Let $\sqrt{\rho}\\,\sigma\sqrt{\rho}=U\Lambda U^\dagger$ be the eigendecomposition of $\sqrt{\rho}\\,\sigma\sqrt{\rho}$. In einsum notation, we have
+
+$$\Lambda_{k\ell} = U^\dagger_ {kp}\sqrt{\rho}_ {pm}\sigma_{mn}\sqrt{\rho}_ {nq}U_ {q\ell}$$
+
+whence
+
+$$\frac{\partial\Lambda_{k\ell}}{\partial \sigma_ {mn}}=\sqrt{\rho}_ {nq}U_{q\ell}U^\dagger_{kp}\sqrt{\rho}_{pm}$$
+
+Then equation (1) gives
+
+$$F(\rho,\sigma)=\text{Tr}\left(\sqrt{\sqrt{\rho}\\,\sigma\\,\sqrt{\rho}}\right)^2=\left(\sum_{k=1}^d\sqrt{\Lambda_{kk}}\right)^2$$
+
+and the gradient is easily seen to be
+
+$$\begin{align*}
+\frac{\partial F}{\partial\sigma_{mn}}&=2\left(\sum_{k=1}^d\sqrt{\Lambda_{kk}}\right)\sum_{k=1}^d\sqrt{\rho}_ {nq}U_{qk}\frac{1}{2\sqrt{\Lambda_{kk}}}U^\dagger_{kp}\sqrt{\rho}_{pm}\\
+\end{align*}$$
+
+Converting back to matrix notation, this is
+
+$$\begin{align}
+\frac{\partial F}{\partial\sigma}=\sqrt{F(\rho,\sigma)}\left(\sqrt{\rho}\\,\left(\sqrt{\rho}\\,\sigma\sqrt{\rho}\right)^{-\frac{1}{2}}\sqrt{\rho}\right)^T\qquad\qquad\qquad\text{(2)}
+\end{align}$$
+
+Since $F(\rho,\sigma)=F(\sigma,\rho)$, it follows immediately that
+
+$$\begin{align}
+\frac{\partial F}{\partial\rho}=\sqrt{F(\rho,\sigma)}\left(\sqrt{\sigma}\\,\left(\sqrt{\sigma}\\,\rho\sqrt{\sigma}\right)^{-\frac{1}{2}}\sqrt{\sigma}\right)^T\qquad\qquad\qquad\text{(3)}
+\end{align}$$
+
+Equations (2) and (3) are what we use for the custom gradient of the fidelity function.
+
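+As a sanity check, equation (2) can be verified against central finite differences using plain NumPy (a sketch assuming real, full-rank inputs, for which the transpose in equation (2) is a no-op):
+
+```python
+import numpy as np
+
+def sqrtm_psd(a):
+    # Matrix square root of a PSD matrix via its eigendecomposition
+    evs, u = np.linalg.eigh(a)
+    return (u * np.sqrt(np.clip(evs, 0.0, None))) @ u.conj().T
+
+def fidelity(rho, sigma):
+    evs = np.linalg.eigvalsh(sqrtm_psd(rho) @ sigma @ sqrtm_psd(rho))
+    return np.sum(np.sqrt(np.clip(evs, 0.0, None))) ** 2
+
+def grad_sigma(rho, sigma):
+    # Equation (2)
+    sr = sqrtm_psd(rho)
+    inv_sqrt = np.linalg.inv(sqrtm_psd(sr @ sigma @ sr))
+    return np.sqrt(fidelity(rho, sigma)) * (sr @ inv_sqrt @ sr).T
+
+rho = np.diag([0.7, 0.3])
+sigma = np.array([[0.6, 0.1], [0.1, 0.4]])  # full rank, does not commute with rho
+
+# Central differences along the diagonal of sigma (symmetric perturbations)
+eps = 1e-6
+for m in range(2):
+    d = np.zeros((2, 2))
+    d[m, m] = eps
+    fd = (fidelity(rho, sigma + d) - fidelity(rho, sigma - d)) / (2 * eps)
+    print(np.isclose(fd, grad_sigma(rho, sigma)[m, m]))  # True, True
+```
+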
+Note that these equations are what Autograd and JAX use for the gradient of a real-valued function defined on complex inputs. On the other hand, PyTorch and TensorFlow use the Wirtinger derivatives given by
+
+$$\begin{align}
+\frac{\partial F}{\partial\rho^\*}&=\sqrt{F(\rho,\sigma)}\sqrt{\sigma}\\,\left(\sqrt{\sigma}\\,\rho\sqrt{\sigma}\right)^{-\frac{1}{2}}\sqrt{\sigma}\\
+\frac{\partial F}{\partial\sigma^\*}&=\sqrt{F(\rho,\sigma)}\sqrt{\rho}\\,\left(\sqrt{\rho}\\,\sigma\sqrt{\rho}\right)^{-\frac{1}{2}}\sqrt{\rho}
+\end{align}$$
diff --git a/pennylane/math/quantum.py b/pennylane/math/quantum.py
index f7020c89b6f..a27f660ed48 100644
--- a/pennylane/math/quantum.py
+++ b/pennylane/math/quantum.py
@@ -720,139 +720,6 @@ def _compute_mutual_info(
return vn_entropy_1 + vn_entropy_2 - vn_entropy_12
-def fidelity(state0, state1, check_state=False, c_dtype="complex128"):
- r"""Compute the fidelity for two states (given as density matrices) acting on quantum
- systems with the same size.
-
- The fidelity for two mixed states given by density matrices :math:`\rho` and :math:`\sigma`
- is defined as
-
- .. math::
- F( \rho , \sigma ) = \text{Tr}( \sqrt{\sqrt{\rho} \sigma \sqrt{\rho}})^2
-
- .. note::
- It supports all interfaces (Numpy, Autograd, Torch, Tensorflow and Jax). The second state is coerced
- to the type and dtype of the first state. The fidelity is returned in the type of the interface of the
- first state.
-
- Args:
- state0 (tensor_like): ``(2**N, 2**N)`` or ``(batch_dim, 2**N, 2**N)`` density matrix.
- state1 (tensor_like): ``(2**N, 2**N)`` or ``(batch_dim, 2**N, 2**N)`` density matrix.
- check_state (bool): If True, the function will check the validity of both states; that is,
- (shape, trace, positive-definitiveness) for density matrices.
- c_dtype (str): Complex floating point precision type.
-
- Returns:
- float: Fidelity between the two quantum states.
-
- **Example**
-
- To find the fidelity between two state vectors, call :func:`~.math.dm_from_state_vector` on the
- inputs first, e.g.:
-
- >>> state0 = qml.math.dm_from_state_vector([0.98753537-0.14925137j, 0.00746879-0.04941796j])
- >>> state1 = qml.math.dm_from_state_vector([0.99500417+0.j, 0.09983342+0.j])
- >>> qml.math.fidelity(state0, state1)
- 0.9905158135644924
-
- To find the fidelity between two density matrices, they can be passed directly:
-
- >>> state0 = [[1, 0], [0, 0]]
- >>> state1 = [[0, 0], [0, 1]]
- >>> qml.math.fidelity(state0, state1)
- 0.0
-
- .. seealso:: :func:`pennylane.math.fidelity_statevector` and :func:`pennylane.qinfo.transforms.fidelity`
-
- """
- # Cast as a c_dtype array
- state0 = cast(state0, dtype=c_dtype)
-
- # Cannot be cast_like if jit
- if not is_abstract(state0):
- state1 = cast_like(state1, state0)
-
- if check_state:
- _check_density_matrix(state0)
- _check_density_matrix(state1)
-
- if qml.math.shape(state0)[-1] != qml.math.shape(state1)[-1]:
- raise qml.QuantumFunctionError("The two states must have the same number of wires.")
-
- # Two mixed states
- fid = _compute_fidelity(state0, state1)
- return fid
-
-
-def fidelity_statevector(state0, state1, check_state=False, c_dtype="complex128"):
- r"""Compute the fidelity for two states (given as state vectors) acting on quantum
- systems with the same size.
-
- The fidelity for two pure states given by state vectors :math:`\ket{\psi}` and :math:`\ket{\phi}`
- is defined as
-
- .. math::
- F( \ket{\psi} , \ket{\phi}) = \left|\braket{\psi, \phi}\right|^2
-
- This is faster than calling :func:`pennylane.math.fidelity` on the density matrix
- representation of pure states.
-
- .. note::
- It supports all interfaces (Numpy, Autograd, Torch, Tensorflow and Jax). The second state is coerced
- to the type and dtype of the first state. The fidelity is returned in the type of the interface of the
- first state.
-
- Args:
- state0 (tensor_like): ``(2**N)`` or ``(batch_dim, 2**N)`` state vector.
- state1 (tensor_like): ``(2**N)`` or ``(batch_dim, 2**N)`` state vector.
- check_state (bool): If True, the function will check the validity of both states; that is,
- the shape and the norm
- c_dtype (str): Complex floating point precision type.
-
- Returns:
- float: Fidelity between the two quantum states.
-
- **Example**
-
- Two state vectors can be used as arguments and the fidelity (overlap) is returned, e.g.:
-
- >>> state0 = [0.98753537-0.14925137j, 0.00746879-0.04941796j]
- >>> state1 = [0.99500417+0.j, 0.09983342+0.j]
- >>> qml.math.fidelity(state0, state1)
- 0.9905158135644924
-
- .. seealso:: :func:`pennylane.math.fidelity` and :func:`pennylane.qinfo.transforms.fidelity`
-
- """
- # Cast as a c_dtype array
- state0 = cast(state0, dtype=c_dtype)
-
- # Cannot be cast_like if jit
- if not is_abstract(state0):
- state1 = cast_like(state1, state0)
-
- if check_state:
- _check_state_vector(state0)
- _check_state_vector(state1)
-
- if qml.math.shape(state0)[-1] != qml.math.shape(state1)[-1]:
- raise qml.QuantumFunctionError("The two states must have the same number of wires.")
-
- batched0 = len(qml.math.shape(state0)) > 1
- batched1 = len(qml.math.shape(state1)) > 1
-
- # Two pure states, squared overlap
- indices0 = "ab" if batched0 else "b"
- indices1 = "ab" if batched1 else "b"
- target = "a" if batched0 or batched1 else ""
- overlap = qml.math.einsum(
- f"{indices0},{indices1}->{target}", state0, np.conj(state1), optimize="greedy"
- )
-
- overlap = np.abs(overlap) ** 2
- return overlap
-
-
def sqrt_matrix(density_matrix):
r"""Compute the square root matrix of a density matrix where :math:`\rho = \sqrt{\rho} \times \sqrt{\rho}`
@@ -863,7 +730,7 @@ def sqrt_matrix(density_matrix):
(tensor_like): Square root of the density matrix.
"""
evs, vecs = qml.math.linalg.eigh(density_matrix)
- evs = np.real(evs)
+ evs = qml.math.real(evs)
evs = qml.math.where(evs > 0.0, evs, 0.0)
if not is_abstract(evs):
evs = qml.math.cast_like(evs, vecs)
@@ -871,32 +738,11 @@ def sqrt_matrix(density_matrix):
shape = qml.math.shape(density_matrix)
if len(shape) > 2:
# broadcasting case
- sqrt_evs = qml.math.expand_dims(qml.math.sqrt(evs), 1) * qml.math.eye(shape[-1])
+ i = qml.math.cast_like(qml.math.convert_like(qml.math.eye(shape[-1]), evs), evs)
+ sqrt_evs = qml.math.expand_dims(qml.math.sqrt(evs), 1) * i
return vecs @ sqrt_evs @ qml.math.conj(qml.math.transpose(vecs, (0, 2, 1)))
- return vecs @ qml.math.diag(np.sqrt(evs)) @ np.conj(np.transpose(vecs))
-
-
-def _compute_fidelity(density_matrix0, density_matrix1):
- r"""Compute the fidelity for two density matrices with the same number of wires.
-
- .. math::
- F( \rho , \sigma ) = -\text{Tr}( \sqrt{\sqrt{\rho} \sigma \sqrt{\rho}})^2
- """
- # Implementation in single dispatches (sqrt(rho))
- sqrt_mat = qml.math.sqrt_matrix(density_matrix0)
-
- # sqrt(rho) * sigma * sqrt(rho)
- sqrt_mat_sqrt = sqrt_mat @ density_matrix1 @ sqrt_mat
-
- # extract eigenvalues
- evs = qml.math.eigvalsh(sqrt_mat_sqrt)
- evs = np.real(evs)
- evs = qml.math.where(evs > 0.0, evs, 0.0)
-
- trace = (qml.math.sum(qml.math.sqrt(evs), -1)) ** 2
-
- return trace
+ return vecs @ qml.math.diag(qml.math.sqrt(evs)) @ qml.math.conj(qml.math.transpose(vecs))
def _compute_relative_entropy(rho, sigma, base=None):
diff --git a/tests/math/test_fidelity_math.py b/tests/math/test_fidelity_math.py
index 1e3db353ac3..da4a38df49a 100644
--- a/tests/math/test_fidelity_math.py
+++ b/tests/math/test_fidelity_math.py
@@ -216,3 +216,223 @@ def test_broadcast_dm_dm_unbatched(self, check_state, func):
fidelity = qml.math.fidelity(state0, state1, check_state)
assert qml.math.allclose(fidelity, expected)
+
+
+def cost_fn_single(x):
+ first_term = qml.math.convert_like(qml.math.diag([1.0, 0]), x)
+ second_term = qml.math.convert_like(qml.math.diag([0, 1.0]), x)
+
+ x = qml.math.cast_like(x, first_term)
+ if len(qml.math.shape(x)) == 0:
+ state1 = qml.math.cos(x / 2) ** 2 * first_term + qml.math.sin(x / 2) ** 2 * second_term
+ else:
+ # broadcasting
+ x = x[:, None, None]
+ state1 = qml.math.cos(x / 2) ** 2 * first_term + qml.math.sin(x / 2) ** 2 * second_term
+
+ state2 = qml.math.convert_like(qml.math.diag([1, 0]), state1)
+
+ return qml.math.fidelity(state1, state2) + qml.math.fidelity(state2, state1)
+
+
+def cost_fn_multi1(x):
+ first_term = qml.math.convert_like(qml.math.diag([1.0, 0, 0, 0]), x)
+ second_term = qml.math.convert_like(qml.math.diag([0, 0, 0, 1.0]), x)
+
+ x = qml.math.cast_like(x, first_term)
+
+ if len(qml.math.shape(x)) == 0:
+ state1 = qml.math.cos(x / 2) ** 2 * first_term + qml.math.sin(x / 2) ** 2 * second_term
+ else:
+ # broadcasting
+ x = x[:, None, None]
+ state1 = qml.math.cos(x / 2) ** 2 * first_term + qml.math.sin(x / 2) ** 2 * second_term
+
+ state2 = qml.math.convert_like(qml.math.diag([1, 0, 0, 0]), state1)
+
+ return qml.math.fidelity(state1, state2) + qml.math.fidelity(state2, state1)
+
+
+def cost_fn_multi2(x):
+ first_term = qml.math.convert_like(np.ones((4, 4)) / 4, x)
+ second_term = np.zeros((4, 4))
+ second_term[1:3, 1:3] = np.array([[1, -1], [-1, 1]]) / 2
+ second_term = qml.math.convert_like(second_term, x)
+
+ x = qml.math.cast_like(x, first_term)
+
+ if len(qml.math.shape(x)) == 0:
+ state1 = qml.math.cos(x / 2) ** 2 * first_term + qml.math.sin(x / 2) ** 2 * second_term
+ else:
+ # broadcasting
+ x = x[:, None, None]
+ state1 = qml.math.cos(x / 2) ** 2 * first_term + qml.math.sin(x / 2) ** 2 * second_term
+
+ state2 = qml.math.convert_like(qml.math.diag([1, 0, 0, 0]), state1)
+
+ return qml.math.fidelity(state1, state2) + qml.math.fidelity(state2, state1)
+
+
+def expected_res_single(x):
+ return 2 * qml.math.cos(x / 2) ** 2
+
+
+def expected_res_multi1(x):
+ return 2 * qml.math.cos(x / 2) ** 2
+
+
+def expected_res_multi2(x):
+ return qml.math.cos(x / 2) ** 2 / 2
+
+
+def expected_grad_single(x):
+ return -qml.math.sin(x)
+
+
+def expected_grad_multi1(x):
+ return -qml.math.sin(x)
+
+
+def expected_grad_multi2(x):
+ return -qml.math.sin(x) / 4
+
+
+class TestGradient:
+ """Test the gradient of qml.math.fidelity"""
+
+ # pylint: disable=too-many-arguments
+
+ cost_fns = [
+ (cost_fn_single, expected_res_single, expected_grad_single),
+ (cost_fn_multi1, expected_res_multi1, expected_grad_multi1),
+ (cost_fn_multi2, expected_res_multi2, expected_grad_multi2),
+ ]
+
+ @pytest.mark.autograd
+ @pytest.mark.parametrize("x", [0.0, 1e-7, 0.456, np.pi / 2 - 1e-7, np.pi / 2])
+ @pytest.mark.parametrize("cost_fn, expected_res, expected_grad", cost_fns)
+ def test_grad_autograd(self, x, cost_fn, expected_res, expected_grad, tol):
+ """Test gradients are correct for autograd"""
+ x = np.array(x)
+ res = cost_fn(x)
+ grad = qml.grad(cost_fn)(x)
+
+ assert qml.math.allclose(res, expected_res(x), tol)
+ assert qml.math.allclose(grad, expected_grad(x), tol)
+
+ @pytest.mark.jax
+ @pytest.mark.parametrize("x", [0.0, 1e-7, 0.456, np.pi / 2 - 1e-7, np.pi / 2])
+ @pytest.mark.parametrize("cost_fn, expected_res, expected_grad", cost_fns)
+ def test_grad_jax(self, x, cost_fn, expected_res, expected_grad, tol):
+ """Test gradients are correct for jax"""
+ x = jnp.array(x)
+ res = cost_fn(x)
+ grad = jax.grad(cost_fn)(x)
+
+ assert qml.math.allclose(res, expected_res(x), tol)
+ assert qml.math.allclose(grad, expected_grad(x), tol)
+
+ @pytest.mark.jax
+ @pytest.mark.parametrize("x", [0.0, 1e-7, 0.456, np.pi / 2 - 1e-7, np.pi / 2])
+ @pytest.mark.parametrize("cost_fn, expected_res, expected_grad", cost_fns)
+ def test_grad_jax_jit(self, x, cost_fn, expected_res, expected_grad, tol):
+ """Test gradients are correct for jax-jit"""
+ x = jnp.array(x)
+
+ jitted_cost = jax.jit(cost_fn)
+ res = jitted_cost(x)
+ grad = jax.grad(jitted_cost)(x)
+
+ assert qml.math.allclose(res, expected_res(x), tol)
+ assert qml.math.allclose(grad, expected_grad(x), tol)
+
+ @pytest.mark.torch
+ @pytest.mark.parametrize("x", [0.0, 1e-7, 0.456, np.pi / 2 - 1e-7, np.pi / 2])
+ @pytest.mark.parametrize("cost_fn, expected_res, expected_grad", cost_fns)
+ def test_grad_torch(self, x, cost_fn, expected_res, expected_grad, tol):
+ """Test gradients are correct for torch"""
+ x = torch.from_numpy(np.array(x)).requires_grad_()
+ res = cost_fn(x)
+ res.backward()
+ grad = x.grad
+
+ assert qml.math.allclose(res, expected_res(x), tol)
+ assert qml.math.allclose(grad, expected_grad(x), tol)
+
+ @pytest.mark.tf
+ @pytest.mark.parametrize("x", [0.0, 1e-7, 0.456, np.pi / 2 - 1e-7, np.pi / 2])
+ @pytest.mark.parametrize("cost_fn, expected_res, expected_grad", cost_fns)
+ def test_grad_tf(self, x, cost_fn, expected_res, expected_grad, tol):
+ """Test gradients are correct for tf"""
+ x = tf.Variable(x, trainable=True)
+
+ with tf.GradientTape() as tape:
+ res = cost_fn(x)
+
+ grad = tape.gradient(res, x)
+
+ assert qml.math.allclose(res, expected_res(x), tol)
+ assert qml.math.allclose(grad, expected_grad(x), tol)
+
+ @pytest.mark.autograd
+ @pytest.mark.parametrize("cost_fn, expected_res, expected_grad", cost_fns)
+ def test_broadcast_autograd(self, cost_fn, expected_res, expected_grad, tol):
+ """Test gradients are correct for a broadcasted input for autograd"""
+ x = np.array([0.0, 1e-7, 0.456, np.pi / 2 - 1e-7, np.pi / 2])
+ res = cost_fn(x)
+ grad = qml.math.diag(qml.jacobian(cost_fn)(x))
+
+ assert qml.math.allclose(res, expected_res(x), tol)
+ assert qml.math.allclose(grad, expected_grad(x), tol)
+
+ @pytest.mark.jax
+ @pytest.mark.parametrize("cost_fn, expected_res, expected_grad", cost_fns)
+ def test_broadcast_jax(self, cost_fn, expected_res, expected_grad, tol):
+ """Test gradients are correct for a broadcasted input for jax"""
+ x = jnp.array([0.0, 1e-7, 0.456, np.pi / 2 - 1e-7, np.pi / 2])
+ res = cost_fn(x)
+ grad = qml.math.diag(jax.jacobian(cost_fn)(x))
+
+ assert qml.math.allclose(res, expected_res(x), tol)
+ assert qml.math.allclose(grad, expected_grad(x), tol)
+
+ @pytest.mark.jax
+ @pytest.mark.parametrize("cost_fn, expected_res, expected_grad", cost_fns)
+ def test_broadcast_jax_jit(self, cost_fn, expected_res, expected_grad, tol):
+ """Test gradients are correct for a broadcasted input for jax-jit"""
+ x = jnp.array([0.0, 1e-7, 0.456, np.pi / 2 - 1e-7, np.pi / 2])
+
+ jitted_cost = jax.jit(cost_fn)
+ res = jitted_cost(x)
+ grad = qml.math.diag(jax.jacobian(jitted_cost)(x))
+
+ assert qml.math.allclose(res, expected_res(x), tol)
+ assert qml.math.allclose(grad, expected_grad(x), tol)
+
+ @pytest.mark.torch
+ @pytest.mark.parametrize("cost_fn, expected_res, expected_grad", cost_fns)
+ def test_broadcast_torch(self, cost_fn, expected_res, expected_grad, tol):
+ """Test gradients are correct for a broadcasted input for torch"""
+ x = torch.from_numpy(
+ np.array([0.0, 1e-7, 0.456, np.pi / 2 - 1e-7, np.pi / 2])
+ ).requires_grad_()
+
+ res = cost_fn(x)
+ grad = qml.math.diag(torch.autograd.functional.jacobian(cost_fn, x))
+
+ assert qml.math.allclose(res, expected_res(x), tol)
+ assert qml.math.allclose(grad, expected_grad(x), tol)
+
+ @pytest.mark.tf
+ @pytest.mark.parametrize("cost_fn, expected_res, expected_grad", cost_fns)
+ def test_broadcast_tf(self, cost_fn, expected_res, expected_grad, tol):
+ """Test gradients are correct for a broadcasted input for tf"""
+ x = tf.Variable([0.0, 1e-7, 0.456, np.pi / 2 - 1e-7, np.pi / 2], trainable=True)
+
+ with tf.GradientTape() as tape:
+ res = cost_fn(x)
+
+ grad = tape.gradient(res, x)
+
+ assert qml.math.allclose(res, expected_res(x), tol)
+ assert qml.math.allclose(grad, expected_grad(x), tol)
diff --git a/tests/qinfo/test_fidelity.py b/tests/qinfo/test_fidelity.py
index 4d4604b33cb..90c59386d00 100644
--- a/tests/qinfo/test_fidelity.py
+++ b/tests/qinfo/test_fidelity.py
@@ -462,6 +462,30 @@ def circuit1(x):
fid_grad = tape.gradient(entropy, param)
assert qml.math.allclose(fid_grad, expected_fid)
+ @pytest.mark.tf
+ @pytest.mark.parametrize("param", parameters)
+ @pytest.mark.parametrize("wire", wires)
+ @pytest.mark.parametrize("interface", interfaces)
+ def test_fidelity_qnodes_rx_tworx_tf_grad(self, param, wire, interface):
+ """Test the gradient of the fidelity between two trainable circuits with Tensorflow."""
+ import tensorflow as tf
+
+ dev = qml.device("default.qubit", wires=wire)
+
+ @qml.qnode(dev, interface=interface, diff_method="backprop")
+ def circuit(x):
+ qml.RX(x, wires=0)
+ return qml.state()
+
+ expected = expected_grad_fidelity_rx_pauliz(param)
+ expected_fid = [-expected, expected]
+ params = (tf.Variable(param), tf.Variable(2 * param))
+ with tf.GradientTape() as tape:
+ fid = qml.qinfo.fidelity(circuit, circuit, wires0=[0], wires1=[0])(*params)
+
+ fid_grad = tape.gradient(fid, params)
+ assert qml.math.allclose(fid_grad, expected_fid)
+
interfaces = ["jax"]
@pytest.mark.jax