Add gradient for SVD
jessegrabowski committed Apr 28, 2024
1 parent eb18f0e commit 99a703c
Showing 2 changed files with 190 additions and 3 deletions.
118 changes: 115 additions & 3 deletions pytensor/tensor/nlinalg.py
@@ -1,7 +1,7 @@
import warnings
from collections.abc import Callable
from collections.abc import Callable, Sequence
from functools import partial
from typing import Literal
from typing import Literal, cast

import numpy as np
from numpy.core.numeric import normalize_axis_tuple # type: ignore
@@ -15,7 +15,7 @@
from pytensor.tensor import math as ptm
from pytensor.tensor.basic import as_tensor_variable, diagonal
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.type import dvector, lscalar, matrix, scalar, vector
from pytensor.tensor.type import Variable, dvector, lscalar, matrix, scalar, vector


class MatrixPinv(Op):
@@ -597,6 +597,118 @@ def infer_shape(self, fgraph, node, shapes):
        else:
            return [s_shape]

    def L_op(
        self,
        inputs: Sequence[Variable],
        outputs: Sequence[Variable],
        output_grads: Sequence[Variable],
    ) -> list[Variable]:
        """
        Reverse-mode gradient of the SVD function. Adapted from the autograd implementation here:
        https://github.com/HIPS/autograd/blob/01eacff7a4f12e6f7aebde7c4cb4c1c2633f217d/autograd/numpy/linalg.py#L194
        and the mxnet implementation described in [1]_.

        References
        ----------
        .. [1] Seeger, Matthias, et al. "Auto-differentiating linear algebra." arXiv preprint arXiv:1710.08717 (2017).
        """

        def s_grad_only(
            U: ptb.TensorVariable, VT: ptb.TensorVariable, ds: ptb.TensorVariable
        ) -> list[Variable]:
            A_bar = (U.conj() * ds[..., None, :]) @ VT
            return [A_bar]

        (A,) = (cast(ptb.TensorVariable, x) for x in inputs)

        if not self.compute_uv:
            # We need all the components of the SVD to compute the gradient of A, even if we
            # only use the singular values in the cost function.
            U, _, VT = svd(A, full_matrices=False, compute_uv=True)
            ds = cast(ptb.TensorVariable, output_grads[0])
            return s_grad_only(U, VT, ds)

        elif self.full_matrices:
            raise NotImplementedError(
                "Gradient of svd not implemented for full_matrices=True"
            )

        else:
            U, s, VT = (cast(ptb.TensorVariable, x) for x in outputs)

            # Handle disconnected inputs.
            # If a user asked for all the matrices but then only used a subset in the cost
            # function, the unused outputs will be DisconnectedType. We replace DisconnectedTypes
            # with zero matrices of the correct shapes.
            new_output_grads = []
            is_disconnected = [
                isinstance(x.type, DisconnectedType) for x in output_grads
            ]
            if all(is_disconnected):
                return [DisconnectedType()()]
            elif is_disconnected == [True, False, True]:
                # This is the same as the compute_uv=False case, so we can fall back to that
                # simpler computation without needing to re-compute U and VT.
                ds = cast(ptb.TensorVariable, output_grads[1])
                return s_grad_only(U, VT, ds)

            for disconnected, output_grad, output in zip(
                is_disconnected, output_grads, [U, s, VT]
            ):
                if disconnected:
                    new_output_grads.append(output.zeros_like())
                else:
                    new_output_grads.append(output_grad)

            (dU, ds, dVT) = (cast(ptb.TensorVariable, x) for x in new_output_grads)

            V = VT.T
            dV = dVT.T

            m, n = A.shape[-2:]

            k = ptm.min((m, n))
            eye = ptb.eye(k)

            def h(t):
                """
                Approximation of s_i ** 2 - s_j ** 2, from [1].
                Robust to identical singular values (singular matrix input), although
                gradients are still wrong in this case.
                """
                eps = 1e-8

                # sign(0) = 0 in pytensor, which defeats the whole purpose of this function
                sign_t = ptb.where(ptm.eq(t, 0), 1, ptm.sign(t))
                return ptm.maximum(ptm.abs(t), eps) * sign_t

            numer = ptb.ones((k, k)) - eye
            denom = h(s[None] - s[:, None]) * h(s[None] + s[:, None])
            E = numer / denom
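            # E is zero on the diagonal; off the diagonal it is approximately the reciprocal
            # of the difference of the squared singular values, with h() keeping the
            # denominator finite when two singular values coincide.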

            utgu = U.T @ dU
            vtgv = VT @ dV

            A_bar = (E * (utgu - utgu.conj().T)) * s[..., None, :]
            A_bar = A_bar + eye * ds[..., :, None]
            A_bar = A_bar + s[..., :, None] * (E * (vtgv - vtgv.conj().T))
            A_bar = U.conj() @ A_bar @ VT

            A_bar = ptb.switch(
                ptm.eq(m, n),
                A_bar,
                ptb.switch(
                    ptm.lt(m, n),
                    A_bar
                    + (
                        U / s[..., None, :] @ dVT @ (ptb.eye(n) - V @ V.conj().T)
                    ).conj(),
                    A_bar
                    + (V / s[..., None, :] @ dU.T @ (ptb.eye(m) - U @ U.conj().T)).T,
                ),
            )
            return [A_bar]


def svd(a, full_matrices: bool = True, compute_uv: bool = True):
"""
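
Not part of the commit, but for orientation: a minimal sketch of how the new gradient can be exercised through pytensor's public API (the variable names and the toy cost below are illustrative). Since the sum of squared singular values equals the squared Frobenius norm, the gradient of this cost should equal 2 * A_val, which makes for an easy end-to-end check.

import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import svd

A = pt.dmatrix("A")
U, s, VT = svd(A, full_matrices=False, compute_uv=True)

# Only s feeds the cost, so U and VT arrive disconnected and the L_op falls back
# to the singular-values-only branch.
cost = (s**2).sum()
grad_A = pytensor.grad(cost, A)
f = pytensor.function([A], grad_A)

A_val = np.random.default_rng(0).normal(size=(4, 3))
np.testing.assert_allclose(f(A_val), 2 * A_val)  # sum(s**2) == ||A||_F**2
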
75 changes: 75 additions & 0 deletions tests/tensor/test_nlinalg.py
@@ -8,6 +8,7 @@
import pytensor
from pytensor import function
from pytensor.configdefaults import config
from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.math import _allclose
from pytensor.tensor.nlinalg import (
    SVD,
@@ -215,6 +216,80 @@ def validate_shape(self, shape, compute_uv=True, full_matrices=True):
            outputs = [outputs]
        self._compile_and_check([A], outputs, [A_v], self.op_class, warn=False)

    @pytest.mark.parametrize(
        "compute_uv, full_matrices, gradient_test_case",
        [(False, False, 0)]
        + [(True, False, i) for i in range(8)]
        + [(True, True, i) for i in range(8)],
        ids=(
            ["compute_uv=False, full_matrices=False"]
            + [
                f"compute_uv=True, full_matrices=False, gradient={grad}"
                for grad in ["U", "s", "V", "U+s", "s+V", "U+V", "U+s+V", "None"]
            ]
            + [
                f"compute_uv=True, full_matrices=True, gradient={grad}"
                for grad in ["U", "s", "V", "U+s", "s+V", "U+V", "U+s+V", "None"]
            ]
        ),
    )
    @pytest.mark.parametrize(
        "shape", [(3, 3), (4, 3), (3, 4)], ids=["(3,3)", "(4,3)", "(3,4)"]
    )
    @pytest.mark.parametrize(
        "batched", [True, False], ids=["batched=True", "batched=False"]
    )
    def test_grad(self, compute_uv, full_matrices, gradient_test_case, shape, batched):
        rng = np.random.default_rng(utt.fetch_seed())
        if batched:
            shape = (4, *shape)

        A_v = self.rng.normal(size=shape).astype(config.floatX)
        if full_matrices:
            with pytest.raises(
                NotImplementedError,
                match="Gradient of svd not implemented for full_matrices=True",
            ):
                U, s, V = svd(
                    self.A, compute_uv=compute_uv, full_matrices=full_matrices
                )
                pytensor.grad(s.sum(), self.A)

        elif compute_uv:

            def svd_fn(A, case=0):
                U, s, V = svd(A, compute_uv=compute_uv, full_matrices=full_matrices)
                if case == 0:
                    return U.sum()
                elif case == 1:
                    return s.sum()
                elif case == 2:
                    return V.sum()
                elif case == 3:
                    return U.sum() + s.sum()
                elif case == 4:
                    return s.sum() + V.sum()
                elif case == 5:
                    return U.sum() + V.sum()
                elif case == 6:
                    return U.sum() + s.sum() + V.sum()
                elif case == 7:
                    # All inputs disconnected
                    return as_tensor_variable(3.0)

            utt.verify_grad(
                partial(svd_fn, case=gradient_test_case),
                [A_v],
                rng=rng,
            )

        else:
            utt.verify_grad(
                partial(svd, compute_uv=compute_uv, full_matrices=full_matrices),
                [A_v],
                rng=rng,
            )


def test_tensorsolve():
    rng = np.random.default_rng(utt.fetch_seed())
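
A closing side note, also not part of the commit: the compute_uv=False path exercised above has a particularly simple closed form. For f(A) equal to the sum of the singular values, the gradient is U @ VT, which is the s_grad_only branch with ds set to ones. The snippet below is a quick NumPy finite-difference check of that identity; all names are illustrative.

import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 3))

U, s, VT = np.linalg.svd(A, full_matrices=False)
analytic = U @ VT  # s_grad_only with ds = ones

eps = 1e-6
numeric = np.zeros_like(A)
for i in range(A.shape[0]):
    for j in range(A.shape[1]):
        dA = np.zeros_like(A)
        dA[i, j] = eps
        numeric[i, j] = (
            np.linalg.svd(A + dA, compute_uv=False).sum()
            - np.linalg.svd(A - dA, compute_uv=False).sum()
        ) / (2 * eps)

np.testing.assert_allclose(analytic, numeric, atol=1e-6)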
