From 29fd4a479bf511027fe2245e3a7a6fec8d4837f6 Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Thu, 29 Sep 2022 20:03:27 +0200
Subject: [PATCH] Add solve_discrete_lyapunov and solve_continuous_lyapunov

---
 aesara/tensor/slinalg.py     | 161 ++++++++++++++++++++++++++++++++++-
 tests/tensor/test_slinalg.py |  68 ++++++++++++++-
 2 files changed, 226 insertions(+), 3 deletions(-)

diff --git a/aesara/tensor/slinalg.py b/aesara/tensor/slinalg.py
index 0c84cd36ed..f8b265d8d6 100644
--- a/aesara/tensor/slinalg.py
+++ b/aesara/tensor/slinalg.py
@@ -1,9 +1,10 @@
 import logging
 import warnings
-from typing import Union
+from typing import TYPE_CHECKING, Union
 
 import numpy as np
 import scipy.linalg
+from typing_extensions import Literal
 
 import aesara.tensor
 from aesara.graph.basic import Apply
@@ -11,10 +12,15 @@
 from aesara.tensor import as_tensor_variable
 from aesara.tensor import basic as at
 from aesara.tensor import math as atm
+from aesara.tensor.shape import reshape
 from aesara.tensor.type import matrix, tensor, vector
 from aesara.tensor.var import TensorVariable
 
+if TYPE_CHECKING:
+    from aesara.tensor import TensorLike
+
+
 logger = logging.getLogger(__name__)
 
@@ -668,6 +674,159 @@ def perform(self, node, inputs, outputs):
 
 expm = Expm()
 
+
+class SolveContinuousLyapunov(Op):
+    __props__ = ()
+
+    def make_node(self, A, B):
+        A = as_tensor_variable(A)
+        B = as_tensor_variable(B)
+
+        out_dtype = aesara.scalar.upcast(A.dtype, B.dtype)
+        X = aesara.tensor.matrix(dtype=out_dtype)
+
+        return aesara.graph.basic.Apply(self, [A, B], [X])
+
+    def perform(self, node, inputs, output_storage):
+        (A, B) = inputs
+        X = output_storage[0]
+
+        X[0] = scipy.linalg.solve_continuous_lyapunov(A, B)
+
+    def infer_shape(self, fgraph, node, shapes):
+        return [shapes[0]]
+
+    def grad(self, inputs, output_grads):
+        # Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
+        # Note that they write the equation as A X + X A^H + Q = 0, while scipy
+        # uses A X + X A^H = Q, so minor adjustments need to be made.
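+        # Concretely, in the scipy convention the adjoint S computed below
+        # solves A^H S + S A = -dX, and the sign of Q_bar flips relative to
+        # the paper's Q_bar = S.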
+        A, Q = inputs
+        (dX,) = output_grads
+
+        X = self(A, Q)
+        S = self(A.conj().T, -dX)  # Eq 31, adjusted
+
+        A_bar = S.dot(X.conj().T) + S.conj().T.dot(X)
+        Q_bar = -S  # Eq 29, adjusted
+
+        return [A_bar, Q_bar]
+
+
+class BilinearSolveDiscreteLyapunov(Op):
+    def make_node(self, A, B):
+        A = as_tensor_variable(A)
+        B = as_tensor_variable(B)
+
+        out_dtype = aesara.scalar.upcast(A.dtype, B.dtype)
+        X = aesara.tensor.matrix(dtype=out_dtype)
+
+        return aesara.graph.basic.Apply(self, [A, B], [X])
+
+    def perform(self, node, inputs, output_storage):
+        (A, B) = inputs
+        X = output_storage[0]
+
+        X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear")
+
+    def infer_shape(self, fgraph, node, shapes):
+        return [shapes[0]]
+
+    def grad(self, inputs, output_grads):
+        # Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
+        A, Q = inputs
+        (dX,) = output_grads
+
+        X = self(A, Q)
+
+        # Eq 41, note that it is not written as a proper Lyapunov equation
+        S = self(A.conj().T, dX)
+
+        A_bar = aesara.tensor.linalg.matrix_dot(
+            S, A, X.conj().T
+        ) + aesara.tensor.linalg.matrix_dot(S.conj().T, A, X)
+        Q_bar = S
+        return [A_bar, Q_bar]
+
+
+_solve_continuous_lyapunov = SolveContinuousLyapunov()
+_solve_bilinear_direct_lyapunov = BilinearSolveDiscreteLyapunov()
+
+
+def iscomplexobj(x):
+    type_ = x.type
+    dtype = type_.dtype
+    return "complex" in dtype
+
+
+def _direct_solve_discrete_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable:
+    A_ = as_tensor_variable(A)
+    Q_ = as_tensor_variable(Q)
+
+    if iscomplexobj(A_):
+        AA = kron(A_, A_.conj())
+    else:
+        AA = kron(A_, A_)
+
+    X = solve(at.eye(AA.shape[0]) - AA, Q_.ravel())
+    return reshape(X, Q_.shape)
+
+
+def solve_discrete_lyapunov(
+    A: "TensorLike", Q: "TensorLike", method: Literal["direct", "bilinear"] = "direct"
+) -> TensorVariable:
+    """Solve the discrete Lyapunov equation :math:`A X A^H - X + Q = 0`.
+
+    Parameters
+    ----------
+    A
+        Square matrix of shape ``N x N``; must have the same shape as `Q`.
+    Q
+        Square matrix of shape ``N x N``; must have the same shape as `A`.
+    method
+        Solver method used, one of ``"direct"`` or ``"bilinear"``. ``"direct"``
+        vectorizes the equation with a Kronecker product and solves the
+        resulting linear system. It has a pure Aesara implementation and can
+        thus be cross-compiled to supported backends, so it should be preferred
+        when ``N`` is not large. The direct method scales poorly with the size
+        of ``N``, so the bilinear method (which wraps
+        ``scipy.linalg.solve_discrete_lyapunov``) should be used in those cases.
+
+    Returns
+    -------
+    Square matrix of shape ``N x N``, representing the solution to the
+    Lyapunov equation.
+
+    """
+    if method not in ["direct", "bilinear"]:
+        raise ValueError(
+            f'Parameter "method" must be one of "direct" or "bilinear", found {method}'
+        )
+
+    if method == "direct":
+        return _direct_solve_discrete_lyapunov(A, Q)
+    if method == "bilinear":
+        return _solve_bilinear_direct_lyapunov(A, Q)
+
+
+def solve_continuous_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable:
+    """Solve the continuous Lyapunov equation :math:`A X + X A^H = Q`.
+
+    Parameters
+    ----------
+    A
+        Square matrix of shape ``N x N``; must have the same shape as `Q`.
+    Q
+        Square matrix of shape ``N x N``; must have the same shape as `A`.
+
+    Returns
+    -------
+    Square matrix of shape ``N x N``, representing the solution to the
+    Lyapunov equation.
+
+    """
+
+    return _solve_continuous_lyapunov(A, Q)
+
+
 __all__ = [
     "cholesky",
     "solve",
diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py
index 52b70c2806..073766365e 100644
--- a/tests/tensor/test_slinalg.py
+++ b/tests/tensor/test_slinalg.py
@@ -2,7 +2,6 @@
 import itertools
 
 import numpy as np
-import numpy.linalg
 import pytest
 import scipy
 
@@ -22,6 +21,8 @@
     expm,
     kron,
     solve,
+    solve_continuous_lyapunov,
+    solve_discrete_lyapunov,
     solve_triangular,
 )
 from aesara.tensor.type import dmatrix, matrix, tensor, vector
@@ -508,7 +509,6 @@ def test_expm_grad_3():
 
 
 class TestKron(utt.InferShapeTester):
-
     rng = np.random.default_rng(43)
 
     def setup_method(self):
@@ -546,3 +546,67 @@ def test_numpy_2d(self):
         b = self.rng.random(shp1).astype(config.floatX)
         out = f(a, b)
         assert np.allclose(out, np.kron(a, b))
+
+
+def test_solve_discrete_lyapunov_via_direct_real():
+    N = 5
+    rng = np.random.default_rng(utt.fetch_seed())
+    a = at.dmatrix()
+    q = at.dmatrix()
+    f = function([a, q], solve_discrete_lyapunov(a, q, method="direct"))
+
+    A = rng.normal(size=(N, N))
+    Q = rng.normal(size=(N, N))
+
+    X = f(A, Q)
+    assert np.allclose(A @ X @ A.T - X + Q, 0.0)
+
+    utt.verify_grad(solve_discrete_lyapunov, pt=[A, Q], rng=rng)
+
+
+def test_solve_discrete_lyapunov_via_direct_complex():
+    N = 5
+    rng = np.random.default_rng(utt.fetch_seed())
+    a = at.zmatrix()
+    q = at.zmatrix()
+    f = function([a, q], solve_discrete_lyapunov(a, q, method="direct"))
+
+    A = rng.normal(size=(N, N)) + rng.normal(size=(N, N)) * 1j
+    Q = rng.normal(size=(N, N))
+    X = f(A, Q)
+    assert np.allclose(A @ X @ A.conj().T - X + Q, 0.0)
+
+    # TODO: the .conj() method currently does not have a gradient; add this test when gradients are implemented.
+    # utt.verify_grad(solve_discrete_lyapunov, pt=[A, Q], rng=rng)
+
+
+def test_solve_discrete_lyapunov_via_bilinear():
+    N = 5
+    rng = np.random.default_rng(utt.fetch_seed())
+    a = at.dmatrix()
+    q = at.dmatrix()
+    f = function([a, q], solve_discrete_lyapunov(a, q, method="bilinear"))
+
+    A = rng.normal(size=(N, N))
+    Q = rng.normal(size=(N, N))
+
+    X = f(A, Q)
+    assert np.allclose(A @ X @ A.conj().T - X + Q, 0.0)
+
+    utt.verify_grad(solve_discrete_lyapunov, pt=[A, Q], rng=rng)
+
+
+def test_solve_continuous_lyapunov():
+    N = 5
+    rng = np.random.default_rng(utt.fetch_seed())
+    a = at.dmatrix()
+    q = at.dmatrix()
+    f = function([a, q], solve_continuous_lyapunov(a, q))
+
+    A = rng.normal(size=(N, N))
+    Q = rng.normal(size=(N, N))
+    X = f(A, Q)
+
+    assert np.allclose(A @ X + X @ A.conj().T, Q)
+
+    utt.verify_grad(solve_continuous_lyapunov, pt=[A, Q], rng=rng)
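
For reference, a minimal usage sketch of the new functions (not part of the patch): it compiles both solvers with Aesara and checks the results against the SciPy routines the Ops wrap. All variable names here are illustrative only.

    import numpy as np
    import scipy.linalg

    import aesara
    import aesara.tensor as at
    from aesara.tensor.slinalg import (
        solve_continuous_lyapunov,
        solve_discrete_lyapunov,
    )

    rng = np.random.default_rng(0)
    A_val = rng.normal(size=(5, 5))
    Q_val = rng.normal(size=(5, 5))

    a = at.dmatrix("a")
    q = at.dmatrix("q")

    # Discrete case: X satisfies A X A^H - X + Q = 0
    f_discrete = aesara.function([a, q], solve_discrete_lyapunov(a, q, method="direct"))
    X_d = f_discrete(A_val, Q_val)
    assert np.allclose(A_val @ X_d @ A_val.T - X_d + Q_val, 0.0)
    assert np.allclose(X_d, scipy.linalg.solve_discrete_lyapunov(A_val, Q_val))

    # Continuous case: X satisfies A X + X A^H = Q
    f_continuous = aesara.function([a, q], solve_continuous_lyapunov(a, q))
    X_c = f_continuous(A_val, Q_val)
    assert np.allclose(A_val @ X_c + X_c @ A_val.T, Q_val)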