Add gradient for SVD
jessegrabowski committed Apr 28, 2024
1 parent eb18f0e commit 06f4e4e
Showing 2 changed files with 183 additions and 3 deletions.
115 changes: 112 additions & 3 deletions pytensor/tensor/nlinalg.py
@@ -1,7 +1,7 @@
import warnings
from collections.abc import Callable
from collections.abc import Callable, Sequence
from functools import partial
from typing import Literal
from typing import Literal, cast

import numpy as np
from numpy.core.numeric import normalize_axis_tuple # type: ignore
@@ -15,7 +15,7 @@
from pytensor.tensor import math as ptm
from pytensor.tensor.basic import as_tensor_variable, diagonal
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.type import dvector, lscalar, matrix, scalar, vector
from pytensor.tensor.type import Variable, dvector, lscalar, matrix, scalar, vector


class MatrixPinv(Op):
@@ -597,6 +597,115 @@ def infer_shape(self, fgraph, node, shapes):
else:
return [s_shape]

def L_op(
self,
inputs: Sequence[Variable],
outputs: Sequence[Variable],
output_grads: Sequence[Variable],
) -> list[Variable]:
"""
Reverse-mode gradient of the SVD function. Adapted from the autograd implementation here:
https://github.com/HIPS/autograd/blob/01eacff7a4f12e6f7aebde7c4cb4c1c2633f217d/autograd/numpy/linalg.py#L194
and the MXNet implementation described in [1]_.

References
----------
.. [1] Seeger, Matthias, et al. "Auto-differentiating linear algebra." arXiv preprint arXiv:1710.08717 (2017).
"""

(A,) = (cast(ptb.TensorVariable, x) for x in inputs)

if not self.compute_uv:
# We need all the components of the SVD to compute the gradient of A even if we only use the singular values
# in the cost function.
U, s, VT = svd(A, full_matrices=False, compute_uv=True)
ds = cast(ptb.TensorVariable, output_grads[0])
A_bar = (U.conj() * ds[..., None, :]) @ VT

return [A_bar]

elif self.full_matrices:
raise NotImplementedError(
"Gradient of svd not implemented for full_matrices=True"
)

else:
U, s, VT = (cast(ptb.TensorVariable, x) for x in outputs)

# Handle disconnected inputs
# If a user asked for all the matrices but then only used a subset in the cost function, the unused outputs
# will be DisconnectedType. We replace DisconnectedTypes with zero matrices of the correct shapes.
new_output_grads = []
is_disconnected = [
isinstance(x.type, DisconnectedType) for x in output_grads
]
if all(is_disconnected):
return [DisconnectedType()()]

elif is_disconnected == [True, False, True]:
# This is the same as the compute_uv=False case, so we can fall back to that simpler computation,
# without needing to recompute U and VT
ds = cast(ptb.TensorVariable, output_grads[1])
A_bar = (U.conj() * ds[..., None, :]) @ VT
return [A_bar]

for disconnected, output_grad, output in zip(
is_disconnected, output_grads, [U, s, VT]
):
if disconnected:
new_output_grads.append(output.zeros_like())
else:
new_output_grads.append(output_grad)

(dU, ds, dVT) = (cast(ptb.TensorVariable, x) for x in new_output_grads)

V = VT.T
dV = dVT.T

m, n = A.shape[-2:]

k = ptm.min((m, n))
eye = ptb.eye(k)

def h(t):
"""
Approximation of s_i ** 2 - s_j ** 2, from [1]_.
Robust to identical singular values (singular matrix input), although
gradients are still wrong in this case.
"""
eps = 1e-8

# sign(0) = 0 in pytensor, which defeats the whole purpose of this function
sign_t = ptb.where(ptm.eq(t, 0), 1, ptm.sign(t))
return ptm.maximum(ptm.abs(t), eps) * sign_t

numer = ptb.ones((k, k)) - eye
denom = h(s[None] - s[:, None]) * h(s[None] + s[:, None])
E = numer / denom

utgu = U.T @ dU
vtgv = VT @ dV

A_bar = (E * (utgu - utgu.conj().T)) * s[..., None, :]
A_bar = A_bar + eye * ds[..., :, None]
A_bar = A_bar + s[..., :, None] * (E * (vtgv - vtgv.conj().T))
A_bar = U.conj() @ A_bar @ VT

A_bar = ptb.switch(
ptm.eq(m, n),
A_bar,
ptb.switch(
ptm.lt(m, n),
A_bar
+ (
U / s[..., None, :] @ dVT @ (ptb.eye(n) - V @ V.conj().T)
).conj(),
A_bar
+ (V / s[..., None, :] @ dU.T @ (ptb.eye(m) - U @ U.conj().T)).T,
),
)
return [A_bar]


def svd(a, full_matrices: bool = True, compute_uv: bool = True):
"""
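For orientation, here is a minimal sketch (not part of the commit) of how the new gradient can be exercised from user code. It only relies on entry points visible above or in the public PyTensor API (svd, pytensor.grad, pytensor.function); the variable names, shapes, and the specific cost are illustrative assumptions.

import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import svd

A = pt.dmatrix("A")

# full_matrices=False is required: this commit raises NotImplementedError otherwise.
U, s, VT = svd(A, full_matrices=False, compute_uv=True)

# A scalar cost that only uses the singular values; L_op detects that the U and VT
# outputs are disconnected and falls back to the simpler compute_uv=False formula.
cost = s.sum()
grad_A = pytensor.grad(cost, A)

f = pytensor.function([A], grad_A)
A_val = np.random.default_rng(0).normal(size=(4, 3))
print(f(A_val))  # gradient of the sum of singular values w.r.t. A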
71 changes: 71 additions & 0 deletions tests/tensor/test_nlinalg.py
@@ -215,6 +215,77 @@ def validate_shape(self, shape, compute_uv=True, full_matrices=True):
outputs = [outputs]
self._compile_and_check([A], outputs, [A_v], self.op_class, warn=False)

@pytest.mark.parametrize(
"compute_uv, full_matrices, gradient_test_case",
[(False, False, 0)]
+ [(True, False, i) for i in range(7)]
+ [(True, True, i) for i in range(7)],
ids=(
["compute_uv=False, full_matrices=False"]
+ [
f"compute_uv=True, full_matrices=False, gradient={grad}"
for grad in ["U", "s", "V", "U+s", "s+V", "U+V", "U+s+V"]
]
+ [
f"compute_uv=True, full_matrices=True, gradient={grad}"
for grad in ["U", "s", "V", "U+s", "s+V", "U+V", "U+s+V"]
]
),
)
@pytest.mark.parametrize(
"shape", [(3, 3), (4, 3), (3, 4)], ids=["(3,3)", "(4,3)", "(3,4)"]
)
@pytest.mark.parametrize(
"batched", [True, False], ids=["batched=True", "batched=False"]
)
def test_grad(self, compute_uv, full_matrices, gradient_test_case, shape, batched):
rng = np.random.default_rng(utt.fetch_seed())
if batched:
shape = (4, *shape)

A_v = self.rng.normal(size=shape).astype(config.floatX)
if full_matrices:
with pytest.raises(
NotImplementedError,
match="Gradient of svd not implemented for full_matrices=True",
):
U, s, V = svd(
self.A, compute_uv=compute_uv, full_matrices=full_matrices
)
pytensor.grad(s.sum(), self.A)

elif compute_uv:

def svd_fn(A, case=0):
U, s, V = svd(A, compute_uv=compute_uv, full_matrices=full_matrices)
if case == 0:
return U.sum()
elif case == 1:
return s.sum()
elif case == 2:
return V.sum()
elif case == 3:
return U.sum() + s.sum()
elif case == 4:
return s.sum() + V.sum()
elif case == 5:
return U.sum() + V.sum()
elif case == 6:
return U.sum() + s.sum() + V.sum()

utt.verify_grad(
partial(svd_fn, case=gradient_test_case),
[A_v],
rng=rng,
)

else:
utt.verify_grad(
partial(svd, compute_uv=compute_uv, full_matrices=full_matrices),
[A_v],
rng=rng,
)


def test_tensorsolve():
rng = np.random.default_rng(utt.fetch_seed())
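As an independent sanity check on the math (not part of the commit or its test suite), the singular-value branch of the gradient, A_bar = (U.conj() * ds[..., None, :]) @ VT, can be reproduced in plain NumPy and compared against central finite differences of the cost sum(s). The matrix size, step size, and tolerance below are illustrative assumptions.

import numpy as np

rng = np.random.default_rng(42)
A = rng.normal(size=(4, 3))

# Cost = s.sum(); with d cost / d s = 1 the analytic gradient reduces to U @ VT.
U, s, VT = np.linalg.svd(A, full_matrices=False)
ds = np.ones_like(s)
A_bar = (U * ds[None, :]) @ VT

# Central finite differences of the same cost.
eps = 1e-6
numeric = np.zeros_like(A)
for i in range(A.shape[0]):
    for j in range(A.shape[1]):
        E = np.zeros_like(A)
        E[i, j] = eps
        numeric[i, j] = (
            np.linalg.svd(A + E, compute_uv=False).sum()
            - np.linalg.svd(A - E, compute_uv=False).sum()
        ) / (2 * eps)

print(np.allclose(A_bar, numeric, atol=1e-5))  # expected: True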
