From 06f4e4e084f66a74aabca3c98b8541c0bc96d0f6 Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Tue, 13 Feb 2024 22:43:09 +0800
Subject: [PATCH] Add gradient for `SVD`

---
 pytensor/tensor/nlinalg.py   | 115 ++++++++++++++++++++++++++++++++++-
 tests/tensor/test_nlinalg.py |  71 +++++++++++++++++++++
 2 files changed, 183 insertions(+), 3 deletions(-)

diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py
index 85215fbe06..836a51acd7 100644
--- a/pytensor/tensor/nlinalg.py
+++ b/pytensor/tensor/nlinalg.py
@@ -1,7 +1,7 @@
 import warnings
-from collections.abc import Callable
+from collections.abc import Callable, Sequence
 from functools import partial
-from typing import Literal
+from typing import Literal, cast
 
 import numpy as np
 from numpy.core.numeric import normalize_axis_tuple  # type: ignore
@@ -15,7 +15,7 @@
 from pytensor.tensor import math as ptm
 from pytensor.tensor.basic import as_tensor_variable, diagonal
 from pytensor.tensor.blockwise import Blockwise
-from pytensor.tensor.type import dvector, lscalar, matrix, scalar, vector
+from pytensor.tensor.type import Variable, dvector, lscalar, matrix, scalar, vector
 
 
 class MatrixPinv(Op):
@@ -597,6 +597,115 @@ def infer_shape(self, fgraph, node, shapes):
         else:
             return [s_shape]
 
+    def L_op(
+        self,
+        inputs: Sequence[Variable],
+        outputs: Sequence[Variable],
+        output_grads: Sequence[Variable],
+    ) -> list[Variable]:
+        """
+        Reverse-mode gradient of the SVD function. Adapted from the autograd implementation here:
+        https://github.com/HIPS/autograd/blob/01eacff7a4f12e6f7aebde7c4cb4c1c2633f217d/autograd/numpy/linalg.py#L194
+
+        Also based on the MXNet implementation described in [1]_.
+
+        References
+        ----------
+        .. [1] Seeger, Matthias, et al. "Auto-differentiating linear algebra." arXiv preprint arXiv:1710.08717 (2017).
+        """
+
+        (A,) = (cast(ptb.TensorVariable, x) for x in inputs)
+
+        if not self.compute_uv:
+            # We need all the components of the SVD to compute the gradient of A, even if only the
+            # singular values are used in the cost function.
+            U, s, VT = svd(A, full_matrices=False, compute_uv=True)
+            ds = cast(ptb.TensorVariable, output_grads[0])
+            A_bar = (U.conj() * ds[..., None, :]) @ VT
+
+            return [A_bar]
+
+        elif self.full_matrices:
+            raise NotImplementedError(
+                "Gradient of svd not implemented for full_matrices=True"
+            )
+
+        else:
+            U, s, VT = (cast(ptb.TensorVariable, x) for x in outputs)
+
+            # Handle disconnected output gradients
+            # If a user asked for all the matrices but then only used a subset in the cost function,
+            # the unused outputs will have DisconnectedType gradients. We replace these with zero
+            # matrices of the correct shapes.
+            new_output_grads = []
+            is_disconnected = [
+                isinstance(x.type, DisconnectedType) for x in output_grads
+            ]
+            if all(is_disconnected):
+                return [DisconnectedType()()]
+            elif is_disconnected == [True, False, True]:
+                # This is the same as the compute_uv=False case, so we can fall back to that simpler
+                # computation without needing to re-compute U and VT
+                ds = cast(ptb.TensorVariable, output_grads[1])
+                A_bar = (U.conj() * ds[..., None, :]) @ VT
+                return [A_bar]
+
+            for disconnected, output_grad, output in zip(
+                is_disconnected, output_grads, [U, s, VT]
+            ):
+                if disconnected:
+                    new_output_grads.append(output.zeros_like())
+                else:
+                    new_output_grads.append(output_grad)
+
+            (dU, ds, dVT) = (cast(ptb.TensorVariable, x) for x in new_output_grads)
+
+            V = VT.T
+            dV = dVT.T
+
+            m, n = A.shape[-2:]
+
+            k = ptm.min((m, n))
+            eye = ptb.eye(k)
+
+            def h(t):
+                """
+                Approximation of s_i ** 2 - s_j ** 2, from [1]_.
+                Robust to identical singular values (singular matrix input), although
+                gradients are still wrong in this case.
+                """
+                eps = 1e-8
+
+                # sign(0) = 0 in pytensor, which defeats the whole purpose of this function
+                sign_t = ptb.where(ptm.eq(t, 0), 1, ptm.sign(t))
+                return ptm.maximum(ptm.abs(t), eps) * sign_t
+
+            numer = ptb.ones((k, k)) - eye
+            denom = h(s[None] - s[:, None]) * h(s[None] + s[:, None])
+            # E[i, j] ~ 1 / (s_j ** 2 - s_i ** 2) off the diagonal; h() keeps the denominator away from zero
+            E = numer / denom
+
+            utgu = U.T @ dU
+            vtgv = VT @ dV
+
+            A_bar = (E * (utgu - utgu.conj().T)) * s[..., None, :]
+            A_bar = A_bar + eye * ds[..., :, None]
+            A_bar = A_bar + s[..., :, None] * (E * (vtgv - vtgv.conj().T))
+            A_bar = U.conj() @ A_bar @ VT
+
+            A_bar = ptb.switch(
+                ptm.eq(m, n),
+                A_bar,
+                ptb.switch(
+                    ptm.lt(m, n),
+                    A_bar
+                    + (
+                        U / s[..., None, :] @ dVT @ (ptb.eye(n) - V @ V.conj().T)
+                    ).conj(),
+                    A_bar
+                    + (V / s[..., None, :] @ dU.T @ (ptb.eye(m) - U @ U.conj().T)).T,
+                ),
+            )
+            return [A_bar]
+
 
 def svd(a, full_matrices: bool = True, compute_uv: bool = True):
     """
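Note (not part of the patch): the compute_uv=False branch above reduces to A_bar = U @ diag(s_bar) @ VT.
Below is a minimal NumPy sketch of that formula, checked against central differences for a sum-of-singular-values
cost; the array names, sizes, and tolerances are illustrative only, not taken from the PR:

    import numpy as np

    rng = np.random.default_rng(0)
    A = rng.normal(size=(4, 3))
    U, s, VT = np.linalg.svd(A, full_matrices=False)

    s_bar = np.ones_like(s)                 # gradient of s.sum() with respect to s
    A_bar = (U * s_bar[..., None, :]) @ VT  # same expression as in L_op (real-valued case)

    # central-difference check of d(s.sum()) / dA
    eps = 1e-6
    fd = np.zeros_like(A)
    for i in range(A.shape[0]):
        for j in range(A.shape[1]):
            E = np.zeros_like(A)
            E[i, j] = eps
            fd[i, j] = (
                np.linalg.svd(A + E, compute_uv=False).sum()
                - np.linalg.svd(A - E, compute_uv=False).sum()
            ) / (2 * eps)
    np.testing.assert_allclose(A_bar, fd, atol=1e-5)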
diff --git a/tests/tensor/test_nlinalg.py b/tests/tensor/test_nlinalg.py
index 4bb88c9bc4..f3145434f2 100644
--- a/tests/tensor/test_nlinalg.py
+++ b/tests/tensor/test_nlinalg.py
@@ -215,6 +215,77 @@ def validate_shape(self, shape, compute_uv=True, full_matrices=True):
             outputs = [outputs]
         self._compile_and_check([A], outputs, [A_v], self.op_class, warn=False)
 
+    @pytest.mark.parametrize(
+        "compute_uv, full_matrices, gradient_test_case",
+        [(False, False, 0)]
+        + [(True, False, i) for i in range(7)]
+        + [(True, True, i) for i in range(7)],
+        ids=(
+            ["compute_uv=False, full_matrices=False"]
+            + [
+                f"compute_uv=True, full_matrices=False, gradient={grad}"
+                for grad in ["U", "s", "V", "U+s", "s+V", "U+V", "U+s+V"]
+            ]
+            + [
+                f"compute_uv=True, full_matrices=True, gradient={grad}"
+                for grad in ["U", "s", "V", "U+s", "s+V", "U+V", "U+s+V"]
+            ]
+        ),
+    )
+    @pytest.mark.parametrize(
+        "shape", [(3, 3), (4, 3), (3, 4)], ids=["(3,3)", "(4,3)", "(3,4)"]
+    )
+    @pytest.mark.parametrize(
+        "batched", [True, False], ids=["batched=True", "batched=False"]
+    )
+    def test_grad(self, compute_uv, full_matrices, gradient_test_case, shape, batched):
+        rng = np.random.default_rng(utt.fetch_seed())
+        if batched:
+            shape = (4, *shape)
+
+        A_v = self.rng.normal(size=shape).astype(config.floatX)
+        if full_matrices:
+            with pytest.raises(
+                NotImplementedError,
+                match="Gradient of svd not implemented for full_matrices=True",
+            ):
+                U, s, V = svd(
+                    self.A, compute_uv=compute_uv, full_matrices=full_matrices
+                )
+                pytensor.grad(s.sum(), self.A)
+
+        elif compute_uv:
+
+            def svd_fn(A, case=0):
+                U, s, V = svd(A, compute_uv=compute_uv, full_matrices=full_matrices)
+                if case == 0:
+                    return U.sum()
+                elif case == 1:
+                    return s.sum()
+                elif case == 2:
+                    return V.sum()
+                elif case == 3:
+                    return U.sum() + s.sum()
+                elif case == 4:
+                    return s.sum() + V.sum()
+                elif case == 5:
+                    return U.sum() + V.sum()
+                elif case == 6:
+                    return U.sum() + s.sum() + V.sum()
+
+            utt.verify_grad(
+                partial(svd_fn, case=gradient_test_case),
+                [A_v],
+                rng=rng,
+            )
+
+        else:
+            utt.verify_grad(
+                partial(svd, compute_uv=compute_uv, full_matrices=full_matrices),
+                [A_v],
+                rng=rng,
+            )
+
 
 def test_tensorsolve():
     rng = np.random.default_rng(utt.fetch_seed())
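Note (not part of the patch): with this gradient in place, differentiating through the singular values from user
code should look roughly like the sketch below. It assumes the existing pytensor entry points (pytensor.grad,
pytensor.function, pytensor.tensor.matrix) and the svd wrapper touched above; it is an illustration, not code from
the PR:

    import numpy as np

    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.nlinalg import svd

    A = pt.matrix("A")
    s = svd(A, full_matrices=False, compute_uv=False)  # singular values only
    g = pytensor.grad(s.sum(), A)                       # exercises the new SVD L_op

    f = pytensor.function([A], g)
    x = np.random.default_rng(0).normal(size=(4, 3)).astype(A.dtype)
    print(f(x))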