Skip to content

Commit

Permalink
Add gradient for svd
Browse files Browse the repository at this point in the history
  • Loading branch information
jessegrabowski committed Jan 23, 2024
1 parent b63ee0c commit 207f55e
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 9 deletions.
82 changes: 73 additions & 9 deletions pytensor/tensor/nlinalg.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import warnings
from collections.abc import Sequence
from functools import partial
from typing import Union, cast

import numpy as np

from pytensor import Variable
from pytensor import scalar as ps
from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply
Expand Down Expand Up @@ -582,26 +585,87 @@ def infer_shape(self, fgraph, node, shapes):
else:
return [s_shape]

def L_op(
self,
inputs: Sequence[Variable],
outputs: Sequence[Variable],
output_grads: Sequence[Variable],
) -> list[Variable]:
(A,) = inputs
A = cast(ptb.TensorVariable, A)

def svd(a, full_matrices: bool = True, compute_uv: bool = True):
if not self.compute_uv:
(S_grad,) = output_grads
S_grad = cast(ptb.TensorVariable, S_grad)

# Need U and V so do the whole svd anyway...
[u, s, v] = svd(A, full_matrices=False, compute_uv=True) # type: ignore
u = cast(ptb.TensorVariable, u)

return [(u.conj() * S_grad[..., None, :]) @ v]

elif self.full_matrices:
raise NotImplementedError(
"Gradient of svd not implemented for full_matrices=True"
)

else:
u, s, v = (cast(ptb.TensorVariable, x) for x in outputs)
gu, gs, gv = (cast(ptb.TensorVariable, x) for x in output_grads)

m, n = A.shape[-2:]

k = ptm.min((m, n))
# broadcastable identity array with shape (1, 1, ..., 1, k, k)
# i = anp.reshape(anp.eye(k), anp.concatenate((anp.ones(a.ndim - 2, dtype=int), (k, k))))

eye = ptb.eye(k)
f = 1 / (s[..., None, :] ** 2 - s[..., :, None] ** 2 + eye)

utgu = u.T @ gu
vtgv = v.T @ gv
t1 = f * (utgu - utgu.conj().T * s[..., None, :])
t1 = t1 + eye * gs[..., :, None]
t1 = t1 + s[..., :, None] * (f * (vtgv - vtgv.conj().T))

if u.dtype.startswith("complex"):
t1 = t1 + 1j * ptb.diag(utgu.imag) / s[..., None, :]

t1 = u.conj() @ t1 @ v.T
t1 = cast(ptb.TensorVariable, t1)

if m < n:
eye_n = ptb.eye(n)
i_minus_vtt = eye_n - (v @ v.conj().T)
t1 = t1 + (u / s[..., None, :] @ gv.T @ i_minus_vtt).conj()

elif m > n:
eye_m = ptb.eye(n)
i_minus_uut = eye_m - u @ u.conj().T
t1 = t1 + v / s[..., None, :] @ gu.T @ i_minus_uut

return [t1]


def svd(
a, full_matrices: bool = True, compute_uv: bool = True
) -> Union[Variable, list[Variable]]:
"""
This function performs the SVD on CPU.
Parameters
----------
full_matrices : bool, optional
If True (default), u and v have the shapes (M, M) and (N, N),
respectively.
Otherwise, the shapes are (M, K) and (K, N), respectively,
where K = min(M, N).
If True (default), u and v have the shapes (M, M) and (N, N), respectively. Otherwise, the shapes are (M, K)
and (K, N), respectively, where K = min(M, N).
compute_uv : bool, optional
Whether or not to compute u and v in addition to s.
True by default.
Whether or not to compute u and v in addition to s. True by default.
Returns
-------
U, V, D : matrices
matrices: TensorVariable or list of TensorVariable
Result of singular value decomposition. If compute_uv is True, return a list of TensorVariable [U, S, V].
Otherwise, returns only singular values S.
"""
return SVD(full_matrices, compute_uv)(a)

Expand Down
34 changes: 34 additions & 0 deletions tests/tensor/test_nlinalg.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import functools as ft

import numpy as np
import numpy.linalg
import pytest
Expand Down Expand Up @@ -189,6 +191,38 @@ def validate_shape(self, shape, compute_uv=True, full_matrices=True):
outputs = [outputs]
self._compile_and_check([A], outputs, [A_v], self.op_class, warn=False)

@pytest.mark.parametrize(
"compute_uv, full_matrices",
[(True, False), (False, False), (True, True)],
ids=[
"compute_uv=True, full_matrices=False",
"compute_uv=False, full_matrices=False",
"compute_uv=True, full_matrices=True",
],
)
def test_grad(self, compute_uv, full_matrices):
A_v = self.rng.random((4, 4)).astype(self.dtype)
if full_matrices:
with pytest.raises(
NotImplementedError,
match="Gradient of svd not implemented for full_matrices=True",
):
u, s, v = svd(
self.A, compute_uv=compute_uv, full_matrices=full_matrices
)
pytensor.grad(s.sum(), self.A)
elif compute_uv:
# u, s, v = svd(self.A, compute_uv=compute_uv, full_matrices=full_matrices)
# op = pytensor.compile.builders.OpFromGraph([self.A], [s])
# utt.verify_grad(op,[A_v], rng=np.random)
pytest.mark.skip("Gradients of function with multiple outputs not testable")
else:
utt.verify_grad(
ft.partial(svd, compute_uv=compute_uv, full_matrices=full_matrices),
[A_v],
rng=np.random,
)


def test_tensorsolve():
rng = np.random.default_rng(utt.fetch_seed())
Expand Down

0 comments on commit 207f55e

Please sign in to comment.