
[BUGFIX] Fix NaN values in gradient of qml.math.fidelity #4380

Merged
39 commits merged on Aug 23, 2023
Changes from 25 commits
Commits (39)
5987479
Fix grad of sqrt_matrix
eddddddy Jul 21, 2023
a4a1f3b
fix broadcasting
eddddddy Jul 21, 2023
1d491dd
Add more tests
eddddddy Jul 27, 2023
58970f2
squash
eddddddy Jul 27, 2023
8ed082d
fix import errors
eddddddy Jul 27, 2023
01f5075
Fix jax
eddddddy Jul 27, 2023
2e0edfe
Fix eigh for jax and torch
eddddddy Jul 28, 2023
e4e8abe
Merge branch 'master' into fidelity_fix
eddddddy Jul 31, 2023
2124e0e
Fix jax.ad deprecation. (#4403)
vincentmr Jul 31, 2023
34782b4
fix `has_decomposition` for ControlledQubitUnitary (#4407)
timmysilv Aug 1, 2023
c1e51f4
[sc-36527]: Add new robots.txt to doc build to hide latest build from…
rashidnhm Aug 1, 2023
3978887
Adding a `wire_order` kwarg to `Tensor.sparse_matrix()` (#4424)
BorjaRequena Aug 2, 2023
ed3a358
Fix `split_non_commuting` when tape contains both `expval` and `var` …
eddddddy Aug 2, 2023
7e29298
Adds shots to experimental device interface and integrate with QNode …
albi3ro Aug 2, 2023
3969eb5
QNSPSA bugfix (#4421)
albi3ro Aug 2, 2023
b5fdeac
Merge `v0.31.1-rc0` branch into master (#4428)
eddddddy Aug 2, 2023
b66d525
Integrate `TransformProgram` with `QNode` (#4404)
albi3ro Aug 3, 2023
fd06f61
move multiprocessing pre-processing to preprocess (#4425)
timmysilv Aug 3, 2023
e006778
Add jit tests
eddddddy Aug 3, 2023
0cc2cf1
Merge branch 'master' into fidelity_fix
eddddddy Aug 3, 2023
4994940
custom vjp for fidelity
eddddddy Aug 15, 2023
25c3607
Merge branch 'master' into fidelity_fix
eddddddy Aug 15, 2023
1ffd569
Update pennylane/math/quantum.py
eddddddy Aug 16, 2023
93caad7
pylint
eddddddy Aug 16, 2023
290ca7d
more pylint
eddddddy Aug 16, 2023
2be9d8e
import at runtime
eddddddy Aug 17, 2023
02a7634
combine single and multi tests
eddddddy Aug 18, 2023
424adb5
pylint
eddddddy Aug 18, 2023
a6c43ca
Merge branch 'master' into fidelity_fix
eddddddy Aug 18, 2023
75f505b
Merge branch 'master' into fidelity_fix
eddddddy Aug 18, 2023
58e7aea
changelog
eddddddy Aug 18, 2023
348eb26
Merge branch 'master' into fidelity_fix
eddddddy Aug 21, 2023
08948b9
Add docs
eddddddy Aug 22, 2023
256af4e
Merge branch 'v0.32.0-rc0' into fidelity_fix
eddddddy Aug 22, 2023
bf9cdae
Rendering changes
eddddddy Aug 22, 2023
2c6e615
Update pennylane/math/fidelity.py
eddddddy Aug 22, 2023
70b0349
address some comments
eddddddy Aug 23, 2023
766e36f
Merge branch 'v0.32.0-rc0' into fidelity_fix
eddddddy Aug 23, 2023
e3fce19
Merge branch 'v0.32.0-rc0' into fidelity_fix
eddddddy Aug 23, 2023
3 changes: 1 addition & 2 deletions pennylane/math/__init__.py
@@ -67,8 +67,6 @@
 from .quantum import (
     cov_matrix,
     dm_from_state_vector,
-    fidelity,
-    fidelity_statevector,
     marginal_prob,
     mutual_info,
     purity,
@@ -80,6 +78,7 @@
     max_entropy,
     trace_distance,
 )
+from .fidelity import fidelity, fidelity_statevector
 from .utils import (
     allclose,
     allequal,
331 changes: 331 additions & 0 deletions pennylane/math/fidelity.py
@@ -0,0 +1,331 @@
# Copyright 2018-2023 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains the implementation of quantum fidelity.

Note: care needs to be taken to make it fully differentiable across all interfaces
"""

import autograd
import autoray as ar
import pennylane as qml

from .utils import cast
from .quantum import _check_density_matrix, _check_state_vector


def fidelity(state0, state1, check_state=False, c_dtype="complex128"):
r"""Compute the fidelity for two states (given as density matrices) acting on quantum
systems with the same size.

The fidelity for two mixed states given by density matrices :math:`\rho` and :math:`\sigma`
is defined as

.. math::
F( \rho , \sigma ) = \text{Tr}( \sqrt{\sqrt{\rho} \sigma \sqrt{\rho}})^2

.. note::
This function supports all interfaces (NumPy, Autograd, Torch, TensorFlow and JAX). The second
state is coerced to the type and dtype of the first state. The fidelity is returned in the
interface of the first state.

Args:
state0 (tensor_like): ``(2**N, 2**N)`` or ``(batch_dim, 2**N, 2**N)`` density matrix.
state1 (tensor_like): ``(2**N, 2**N)`` or ``(batch_dim, 2**N, 2**N)`` density matrix.
check_state (bool): If True, the function will check the validity of both states; that is,
the shape, trace, and positive semi-definiteness of the density matrices.
c_dtype (str): Complex floating point precision type.

Returns:
float: Fidelity between the two quantum states.

**Example**

To find the fidelity between two state vectors, call :func:`~.math.dm_from_state_vector` on the
inputs first, e.g.:

>>> state0 = qml.math.dm_from_state_vector([0.98753537-0.14925137j, 0.00746879-0.04941796j])
>>> state1 = qml.math.dm_from_state_vector([0.99500417+0.j, 0.09983342+0.j])
>>> qml.math.fidelity(state0, state1)
0.9905158135644924

To find the fidelity between two density matrices, they can be passed directly:

>>> state0 = [[1, 0], [0, 0]]
>>> state1 = [[0, 0], [0, 1]]
>>> qml.math.fidelity(state0, state1)
0.0

.. seealso:: :func:`pennylane.math.fidelity_statevector` and :func:`pennylane.qinfo.transforms.fidelity`

"""
# Cast as a c_dtype array
state0 = cast(state0, dtype=c_dtype)
state1 = cast(state1, dtype=c_dtype)

if check_state:
_check_density_matrix(state0)
_check_density_matrix(state1)

if qml.math.shape(state0)[-1] != qml.math.shape(state1)[-1]:
raise qml.QuantumFunctionError("The two states must have the same number of wires.")

# Two mixed states
fid = qml.math.compute_fidelity(state0, state1)
return fid
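
# Hypothetical sanity check, not part of this diff: with this patch,
# differentiating the fidelity of (nearly) pure density matrices no longer
# yields NaN. A minimal autograd sketch, assuming PennyLane with this fix:
#
# >>> from pennylane import numpy as np
# >>> def cost(theta):
# ...     rho = qml.math.dm_from_state_vector(np.array([np.cos(theta), np.sin(theta)]))
# ...     sigma = qml.math.dm_from_state_vector(np.array([1.0, 0.0]))
# ...     return qml.math.fidelity(rho, sigma)
# >>> qml.grad(cost)(np.array(0.5, requires_grad=True))  # finite, not NaN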


def fidelity_statevector(state0, state1, check_state=False, c_dtype="complex128"):
r"""Compute the fidelity for two states (given as state vectors) acting on quantum
systems with the same size.

The fidelity for two pure states given by state vectors :math:`\ket{\psi}` and :math:`\ket{\phi}`
is defined as

.. math::
F( \ket{\psi} , \ket{\phi}) = \left|\braket{\psi, \phi}\right|^2

This is faster than calling :func:`pennylane.math.fidelity` on the density matrix
representation of pure states.

.. note::
This function supports all interfaces (NumPy, Autograd, Torch, TensorFlow and JAX). The second
state is coerced to the type and dtype of the first state. The fidelity is returned in the
interface of the first state.

Args:
state0 (tensor_like): ``(2**N)`` or ``(batch_dim, 2**N)`` state vector.
state1 (tensor_like): ``(2**N)`` or ``(batch_dim, 2**N)`` state vector.
check_state (bool): If True, the function will check the validity of both states; that is,
the shape and norm of the state vectors.
c_dtype (str): Complex floating point precision type.

Returns:
float: Fidelity between the two quantum states.

**Example**

Two state vectors can be used as arguments and the fidelity (overlap) is returned, e.g.:

>>> state0 = [0.98753537-0.14925137j, 0.00746879-0.04941796j]
>>> state1 = [0.99500417+0.j, 0.09983342+0.j]
>>> qml.math.fidelity_statevector(state0, state1)
0.9905158135644924

.. seealso:: :func:`pennylane.math.fidelity` and :func:`pennylane.qinfo.transforms.fidelity`

"""
# Cast as a c_dtype array
state0 = cast(state0, dtype=c_dtype)
state1 = cast(state1, dtype=c_dtype)

if check_state:
_check_state_vector(state0)
_check_state_vector(state1)

if qml.math.shape(state0)[-1] != qml.math.shape(state1)[-1]:
raise qml.QuantumFunctionError("The two states must have the same number of wires.")

batched0 = len(qml.math.shape(state0)) > 1
batched1 = len(qml.math.shape(state1)) > 1

# Two pure states, squared overlap
indices0 = "ab" if batched0 else "b"
indices1 = "ab" if batched1 else "b"
target = "a" if batched0 or batched1 else ""
overlap = qml.math.einsum(
f"{indices0},{indices1}->{target}", state0, qml.math.conj(state1), optimize="greedy"
)

overlap = qml.math.abs(overlap) ** 2
return overlap
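
# Hypothetical usage sketch, not part of this diff: the einsum index logic
# above broadcasts over an optional leading batch dimension, so a batch of
# state vectors can be compared against a single state:
#
# >>> import numpy as np
# >>> batch = np.array([[1.0, 0.0], [0.0, 1.0]])  # shape (batch_dim, 2**N)
# >>> target = np.array([1.0, 0.0])               # shape (2**N,)
# >>> qml.math.fidelity_statevector(batch, target)  # -> array([1., 0.])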


def _compute_fidelity_vanilla(density_matrix0, density_matrix1):
r"""Compute the fidelity for two density matrices with the same number of wires.

.. math::
F( \rho , \sigma ) = \text{Tr}( \sqrt{\sqrt{\rho} \sigma \sqrt{\rho}})^2
"""
# Implementation in single dispatches (sqrt(rho))
sqrt_mat = qml.math.sqrt_matrix(density_matrix0)

# sqrt(rho) * sigma * sqrt(rho)
sqrt_mat_sqrt = sqrt_mat @ density_matrix1 @ sqrt_mat

# extract eigenvalues
evs = qml.math.eigvalsh(sqrt_mat_sqrt)
evs = qml.math.real(evs)
evs = qml.math.where(evs > 0.0, evs, 1e-15)

trace = (qml.math.sum(qml.math.sqrt(evs), -1)) ** 2

return trace
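
# Derivation sketch, not part of this diff: writing the eigendecomposition
# sqrt(sigma) @ rho @ sqrt(sigma) = U @ diag(evs) @ U^dag, the fidelity is
# F(rho, sigma) = (sum_i sqrt(evs_i))**2, and its gradient with respect to
# rho is
#
#     dF/drho = (sum_i sqrt(evs_i)) * sqrt(sigma) @ U @ diag(evs)**(-1/2) @ U^dag @ sqrt(sigma)
#
# This is the matrix product assembled in _compute_fidelity_vjp0 below, up to
# a transpose that accounts for the complex-differentiation convention.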


def _compute_fidelity_vjp0(dm0, dm1, grad_out):
"""
Compute the VJP of fidelity with respect to the first density matrix
"""
# sqrt of sigma
sqrt_dm1 = qml.math.sqrt_matrix(dm1)

# eigendecomposition of sqrt(sigma) * rho * sqrt(sigma)
evs0, u0 = qml.math.linalg.eigh(sqrt_dm1 @ dm0 @ sqrt_dm1)
evs0 = qml.math.real(evs0)
evs0 = qml.math.where(evs0 > 0.0, evs0, 1e-15)
evs0 = qml.math.cast_like(evs0, sqrt_dm1)

if len(qml.math.shape(dm0)) == 2 and len(qml.math.shape(dm1)) == 2:
u0_dag = qml.math.transpose(qml.math.conj(u0))
grad_dm0 = sqrt_dm1 @ u0 @ (1 / qml.math.sqrt(evs0)[..., None] * u0_dag) @ sqrt_dm1

# torch and tensorflow use the Wirtinger derivative which is a different convention
# than the one autograd and jax use for complex differentiation
if qml.math.get_interface(dm0) in ["torch", "tensorflow"]:
grad_dm0 = qml.math.sum(qml.math.sqrt(evs0), -1) * grad_dm0
else:
grad_dm0 = qml.math.sum(qml.math.sqrt(evs0), -1) * qml.math.transpose(grad_dm0)

res = grad_dm0 * qml.math.cast_like(grad_out, grad_dm0)
return res

# broadcasting case
u0_dag = qml.math.transpose(qml.math.conj(u0), (0, 2, 1))
grad_dm0 = sqrt_dm1 @ u0 @ (1 / qml.math.sqrt(evs0)[..., None] * u0_dag) @ sqrt_dm1

# torch and tensorflow use the Wirtinger derivative which is a different convention
# than the one autograd and jax use for complex differentiation
if qml.math.get_interface(dm0) in ["torch", "tensorflow"]:
grad_dm0 = qml.math.sum(qml.math.sqrt(evs0), -1)[:, None, None] * grad_dm0
else:
grad_dm0 = qml.math.sum(qml.math.sqrt(evs0), -1)[:, None, None] * qml.math.transpose(
grad_dm0, (0, 2, 1)
)

return grad_dm0 * qml.math.cast_like(grad_out, grad_dm0)[:, None, None]
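
# Rationale, not part of this diff: the eigenvalue clamp above
# (evs0 > 0.0, else 1e-15) is what prevents the NaN gradients this PR fixes.
# For rank-deficient (e.g. pure) states, sqrt(sigma) @ rho @ sqrt(sigma) has
# zero eigenvalues, so 1 / sqrt(evs0) would diverge and fill the VJP with NaN.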


def _compute_fidelity_vjp1(dm0, dm1, grad_out):
"""
Compute the VJP of fidelity with respect to the second density matrix
"""
# pylint: disable=arguments-out-of-order
return _compute_fidelity_vjp0(dm1, dm0, grad_out)


def _compute_fidelity_grad(dm0, dm1, grad_out):
return _compute_fidelity_vjp0(dm0, dm1, grad_out), _compute_fidelity_vjp1(dm0, dm1, grad_out)


################################ numpy ###################################

ar.register_function("numpy", "compute_fidelity", _compute_fidelity_vanilla)

################################ autograd ################################


@autograd.extend.primitive
def _compute_fidelity_autograd(dm0, dm1):
return _compute_fidelity_vanilla(dm0, dm1)


def _compute_fidelity_autograd_vjp0(_, dm0, dm1):
def vjp(grad_out):
return _compute_fidelity_vjp0(dm0, dm1, grad_out)

return vjp


def _compute_fidelity_autograd_vjp1(_, dm0, dm1):
def vjp(grad_out):
return _compute_fidelity_vjp1(dm0, dm1, grad_out)

return vjp


autograd.extend.defvjp(
_compute_fidelity_autograd, _compute_fidelity_autograd_vjp0, _compute_fidelity_autograd_vjp1
)
ar.register_function("autograd", "compute_fidelity", _compute_fidelity_autograd)

################################# jax #####################################

try:
import jax

@jax.custom_vjp
def _compute_fidelity_jax(dm0, dm1):
return _compute_fidelity_vanilla(dm0, dm1)

def _compute_fidelity_jax_fwd(dm0, dm1):
fid = _compute_fidelity_jax(dm0, dm1)
return fid, (dm0, dm1)

def _compute_fidelity_jax_bwd(res, grad_out):
dm0, dm1 = res
return _compute_fidelity_grad(dm0, dm1, grad_out)

_compute_fidelity_jax.defvjp(_compute_fidelity_jax_fwd, _compute_fidelity_jax_bwd)
ar.register_function("jax", "compute_fidelity", _compute_fidelity_jax)

except ModuleNotFoundError:
# jax not installed
pass
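
# Hypothetical sketch, not part of this diff; assumes jax is installed and
# float64 mode is enabled. The custom VJP registered above is jit-compatible:
#
# >>> import jax
# >>> import jax.numpy as jnp
# >>> jax.config.update("jax_enable_x64", True)
# >>> def cost(theta):
# ...     state = jnp.array([jnp.cos(theta), jnp.sin(theta)], dtype=jnp.complex128)
# ...     rho = qml.math.dm_from_state_vector(state)
# ...     sigma = qml.math.dm_from_state_vector(jnp.array([1.0, 0.0], dtype=jnp.complex128))
# ...     return qml.math.fidelity(rho, sigma)
# >>> jax.jit(jax.grad(cost))(0.5)  # finite, not NaN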

################################ torch ###################################

try:
import torch

class _TorchFidelity(torch.autograd.Function):
@staticmethod
def forward(ctx, dm0, dm1):
"""Forward pass for _compute_fidelity"""
fid = _compute_fidelity_vanilla(dm0, dm1)
ctx.save_for_backward(dm0, dm1)
return fid

@staticmethod
def backward(ctx, grad_out):
"""Backward pass for _compute_fidelity"""
dm0, dm1 = ctx.saved_tensors
return _compute_fidelity_grad(dm0, dm1, grad_out)

ar.register_function("torch", "compute_fidelity", _TorchFidelity.apply)

except ModuleNotFoundError:
# torch not installed
pass
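
# Hypothetical sketch, not part of this diff; assumes torch is installed.
# The autograd.Function above routes backward() through the custom VJPs:
#
# >>> import torch
# >>> rho = torch.tensor([[1.0, 0.0], [0.0, 0.0]], dtype=torch.complex128, requires_grad=True)
# >>> sigma = torch.tensor([[0.5, 0.5], [0.5, 0.5]], dtype=torch.complex128)
# >>> qml.math.fidelity(rho, sigma).backward()
# >>> torch.isnan(rho.grad).any()  # tensor(False)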

############################### tensorflow ################################

try:
import tensorflow as tf

@tf.custom_gradient
def _compute_fidelity_tf(dm0, dm1):
fid = _compute_fidelity_vanilla(dm0, dm1)

def vjp(grad_out):
return _compute_fidelity_grad(dm0, dm1, grad_out)

return fid, vjp

ar.register_function("tensorflow", "compute_fidelity", _compute_fidelity_tf)

except ModuleNotFoundError:
# tensorflow not installed
pass
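
# Hypothetical sketch, not part of this diff; assumes tensorflow is
# installed. The custom gradient above plugs into GradientTape:
#
# >>> rho = tf.Variable([[1.0 + 0j, 0j], [0j, 0j]], dtype=tf.complex128)
# >>> sigma = tf.constant([[0.5 + 0j, 0.5 + 0j], [0.5 + 0j, 0.5 + 0j]], dtype=tf.complex128)
# >>> with tf.GradientTape() as tape:
# ...     fid = qml.math.fidelity(rho, sigma)
# >>> tape.gradient(fid, rho)  # finite entries, no NaN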