Add gradient for SVD
jessegrabowski committed Feb 14, 2024
1 parent 453fb4d commit b6c79fd
Showing 2 changed files with 131 additions and 2 deletions.
88 changes: 86 additions & 2 deletions pytensor/tensor/nlinalg.py
@@ -1,6 +1,7 @@
import warnings
from collections.abc import Sequence
from functools import partial
from typing import Callable, Literal, Optional, Union
from typing import Callable, Literal, Optional, Union, cast

import numpy as np
from numpy.core.numeric import normalize_axis_tuple # type: ignore
@@ -13,7 +14,7 @@
from pytensor.tensor import math as ptm
from pytensor.tensor.basic import as_tensor_variable, diagonal
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.type import dvector, lscalar, matrix, scalar, vector
from pytensor.tensor.type import Variable, dvector, lscalar, matrix, scalar, vector


class MatrixPinv(Op):
@@ -595,6 +596,89 @@ def infer_shape(self, fgraph, node, shapes):
else:
return [s_shape]

def L_op(
self,
inputs: Sequence[Variable],
outputs: Sequence[Variable],
output_grads: Sequence[Variable],
) -> list[Variable]:
"""
Reverse-mode gradient of the SVD function. Adapted from the autograd implementation here:
https://github.com/HIPS/autograd/blob/01eacff7a4f12e6f7aebde7c4cb4c1c2633f217d/autograd/numpy/linalg.py#L194
        and the MXNet implementation described in [1]_.
References
----------
.. [1] Seeger, Matthias, et al. "Auto-differentiating linear algebra." arXiv preprint arXiv:1710.08717 (2017).
"""
(A,) = (cast(ptb.TensorVariable, x) for x in inputs)

if not self.compute_uv:
# We need all the components of the SVD to compute the gradient of A even if we only use the singular values
# in the cost function.
U, s, VT = svd(A, full_matrices=False, compute_uv=True)

(ds,) = (cast(ptb.TensorVariable, x) for x in output_grads)
A_bar = (U.conj() * ds[..., None, :]) @ VT

return [A_bar]

elif self.full_matrices:
raise NotImplementedError(
"Gradient of svd not implemented for full_matrices=True"
)

else:
U, s, VT = (cast(ptb.TensorVariable, x) for x in outputs)
(dU, ds, dVT) = (cast(ptb.TensorVariable, x) for x in output_grads)
V = VT.T
dV = dVT.T

m, n = A.shape[-2:]

k = ptm.min((m, n))
eye = ptb.eye(k)

def h(t):
"""
Approximation of s_i ** 2 - s_j ** 2, from .. [1].
Robust to identical singular values (singular matrix input), although
gradients are still wrong in this case.
"""
eps = 1e-8

# sign(0) = 0 in pytensor, which defeats the whole purpose of this function
sign_t = ptb.where(ptm.eq(t, 0), 1, ptm.sign(t))
return ptm.maximum(ptm.abs(t), eps) * sign_t

numer = ptb.ones((k, k)) - eye
denom = h(s[None] - s[:, None]) * h(s[None] + s[:, None])
E = numer / denom

utgu = U.T @ dU
vtgv = VT @ dV

A_bar = (E * (utgu - utgu.conj().T)) * s[..., None, :]
A_bar = A_bar + eye * ds[..., :, None]
A_bar = A_bar + s[..., :, None] * (E * (vtgv - vtgv.conj().T))
A_bar = U.conj() @ A_bar @ VT

A_bar = ptb.switch(
ptm.eq(m, n),
A_bar,
ptb.switch(
ptm.lt(m, n),
A_bar
+ (
U / s[..., None, :] @ dVT @ (ptb.eye(n) - V @ V.conj().T)
).conj(),
A_bar
+ (V / s[..., None, :] @ dU.T @ (ptb.eye(m) - U @ U.conj().T)).T,
),
)
return [A_bar]
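
For readers skimming the diff, the square-case branch above can be read as the following expression (my transcription of the code, not part of the commit; E is built from the smoothed helper h, and U-bar, s-bar, V-bar denote the incoming output gradients dU, ds, dV):

\[
\bar{A} = U\left[\left(E \circ (U^\top \bar{U} - \bar{U}^\top U)\right)\operatorname{diag}(s)
+ \operatorname{diag}(\bar{s})
+ \operatorname{diag}(s)\left(E \circ (V^\top \bar{V} - \bar{V}^\top V)\right)\right]V^\top,
\qquad E_{ij} = \frac{1-\delta_{ij}}{h(s_j - s_i)\,h(s_j + s_i)},
\]

with the final ptb.switch adding a projection correction term when the input is rectangular (m != n).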


def svd(a, full_matrices: bool = True, compute_uv: bool = True):
"""
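
A minimal usage sketch of what this change enables (not part of the commit; assumes a PyTensor build that includes it). With compute_uv=False only the singular values are returned, but the new L_op recomputes the full SVD internally to form the gradient, so a cost defined on the singular values alone is now differentiable:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import svd

A = pt.matrix("A")
# Only the singular values are returned here; the L_op added above
# handles the gradient wrt A.
s = svd(A, compute_uv=False, full_matrices=False)
cost = s.sum()
A_grad = pytensor.grad(cost, A)

f = pytensor.function([A], A_grad)
A_val = np.random.default_rng(0).normal(size=(4, 3)).astype(pytensor.config.floatX)
f(A_val)  # for this particular cost the result equals U @ VT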
45 changes: 45 additions & 0 deletions tests/tensor/test_nlinalg.py
@@ -214,6 +214,51 @@ def validate_shape(self, shape, compute_uv=True, full_matrices=True):
outputs = [outputs]
self._compile_and_check([A], outputs, [A_v], self.op_class, warn=False)

@pytest.mark.parametrize(
"compute_uv, full_matrices",
[(True, False), (False, False), (True, True)],
ids=[
"compute_uv=True, full_matrices=False",
"compute_uv=False, full_matrices=False",
"compute_uv=True, full_matrices=True",
],
)
@pytest.mark.parametrize(
"shape", [(3, 3), (4, 3), (3, 4)], ids=["(3,3)", "(4,3)", "(3,4)"]
)
@pytest.mark.parametrize(
"batched", [True, False], ids=["batched=True", "batched=False"]
)
def test_grad(self, compute_uv, full_matrices, shape, batched):
rng = np.random.default_rng(utt.fetch_seed())
if batched:
shape = (4, *shape)

A_v = self.rng.normal(size=shape).astype(config.floatX)
if full_matrices:
with pytest.raises(
NotImplementedError,
match="Gradient of svd not implemented for full_matrices=True",
):
U, s, V = svd(
self.A, compute_uv=compute_uv, full_matrices=full_matrices
)
pytensor.grad(s.sum(), self.A)

elif compute_uv:
utt.verify_grad(
partial(svd, compute_uv=compute_uv, full_matrices=full_matrices),
[A_v],
rng=rng,
)

else:
utt.verify_grad(
partial(svd, compute_uv=compute_uv, full_matrices=full_matrices),
[A_v],
rng=rng,
)
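
The verify_grad calls above do the actual checking. As a self-contained illustration of what they verify in the compute_uv=False case, the sketch below (NumPy only, names are mine, not part of the commit) compares the closed-form gradient of sum(s) against central finite differences:

import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 3))

# Closed form: for cost = s.sum(), the incoming gradient ds is all ones,
# so A_bar = (U * ds) @ VT reduces to U @ VT.
U, s, VT = np.linalg.svd(A, full_matrices=False)
analytic = U @ VT

# Central finite differences, perturbing one entry of A at a time.
eps = 1e-6
numeric = np.zeros_like(A)
for i in range(A.shape[0]):
    for j in range(A.shape[1]):
        dA = np.zeros_like(A)
        dA[i, j] = eps
        plus = np.linalg.svd(A + dA, compute_uv=False).sum()
        minus = np.linalg.svd(A - dA, compute_uv=False).sum()
        numeric[i, j] = (plus - minus) / (2 * eps)

assert np.allclose(analytic, numeric, atol=1e-5)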


def test_tensorsolve():
rng = np.random.default_rng(utt.fetch_seed())
