Add Op corresponding to scipy.linalg.solve_discrete_are #417

Merged: 5 commits, merged Aug 25, 2023
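For orientation, here is a minimal usage sketch of the Op this PR adds. It is not part of the diff; the numeric values are the darex #1 case reused in the tests below, and everything else follows the API shown in the changes to `pytensor/tensor/slinalg.py`:

```python
# Minimal usage sketch of the new solve_discrete_are (illustrative, not part of the diff).
import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.slinalg import solve_discrete_are

A, B, Q, R = pt.dmatrix("A"), pt.dmatrix("B"), pt.dmatrix("Q"), pt.dmatrix("R")
X = solve_discrete_are(A, B, Q, R)  # symbolic solution of the DARE

# darex #1 values, the same test case used in tests/tensor/test_slinalg.py below
a = np.array([[4.0, 3.0], [-4.5, -3.5]])
b = np.array([[1.0], [-1.0]])
q = np.array([[9.0, 6.0], [6.0, 4.0]])
r = np.array([[1.0]])
print(X.eval({A: a, B: b, Q: q, R: r}))
```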
111 changes: 95 additions & 16 deletions pytensor/tensor/slinalg.py
@@ -1,4 +1,5 @@
import logging
import typing
import warnings
from typing import TYPE_CHECKING, Literal, Union

@@ -12,6 +13,7 @@
from pytensor.tensor import as_tensor_variable
from pytensor.tensor import basic as at
from pytensor.tensor import math as atm
from pytensor.tensor.nlinalg import matrix_dot
from pytensor.tensor.shape import reshape
from pytensor.tensor.type import matrix, tensor, vector
from pytensor.tensor.var import TensorVariable
@@ -321,9 +323,6 @@ def L_op(self, inputs, outputs, output_gradients):
        return res


solvetriangular = SolveTriangular()


def solve_triangular(
    a: TensorVariable,
    b: TensorVariable,
@@ -397,9 +396,6 @@ def perform(self, node, inputs, outputs):
)


solve = Solve()


def solve(a, b, assume_a="gen", lower=False, check_finite=True):
"""Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.

@@ -748,13 +744,9 @@ def grad(self, inputs, output_grads):


_solve_continuous_lyapunov = SolveContinuousLyapunov()
_solve_bilinear_direct_lyapunov = BilinearSolveDiscreteLyapunov()


def iscomplexobj(x):
    type_ = x.type
    dtype = type_.dtype
    return "complex" in dtype
_solve_bilinear_direct_lyapunov = typing.cast(
    typing.Callable, BilinearSolveDiscreteLyapunov()
)


def _direct_solve_discrete_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable:
@@ -767,7 +759,7 @@ def _direct_solve_discrete_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable:
    AA = kron(A_, A_)

    X = solve(pt.eye(AA.shape[0]) - AA, Q_.ravel())
    return reshape(X, Q_.shape)
    return typing.cast(TensorVariable, reshape(X, Q_.shape))


def solve_discrete_lyapunov(
@@ -803,7 +795,7 @@ def solve_discrete_lyapunov(
    if method == "direct":
        return _direct_solve_discrete_lyapunov(A, Q)
    if method == "bilinear":
        return _solve_bilinear_direct_lyapunov(A, Q)
        return typing.cast(TensorVariable, _solve_bilinear_direct_lyapunov(A, Q))


def solve_continuous_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable:
@@ -823,7 +815,90 @@ def solve_continuous_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable:

    """

    return _solve_continuous_lyapunov(A, Q)
    return typing.cast(TensorVariable, _solve_continuous_lyapunov(A, Q))


class SolveDiscreteARE(pt.Op):
    __props__ = ("enforce_Q_symmetric",)

    def __init__(self, enforce_Q_symmetric=False):
        self.enforce_Q_symmetric = enforce_Q_symmetric

    def make_node(self, A, B, Q, R):
        A = as_tensor_variable(A)
        B = as_tensor_variable(B)
        Q = as_tensor_variable(Q)
        R = as_tensor_variable(R)

        out_dtype = pytensor.scalar.upcast(A.dtype, B.dtype, Q.dtype, R.dtype)
        X = pytensor.tensor.matrix(dtype=out_dtype)

        return pytensor.graph.basic.Apply(self, [A, B, Q, R], [X])

    def perform(self, node, inputs, output_storage):
        A, B, Q, R = inputs
        X = output_storage[0]

        if self.enforce_Q_symmetric:
            Q = 0.5 * (Q + Q.T)

        X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R).astype(
            node.outputs[0].type.dtype
        )

    def infer_shape(self, fgraph, node, shapes):
        return [shapes[0]]

    def grad(self, inputs, output_grads):
        # Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
        A, B, Q, R = inputs

        (dX,) = output_grads
        X = self(A, B, Q, R)

        K_inner = R + pt.linalg.matrix_dot(B.T, X, B)
        K_inner_inv = pt.linalg.solve(K_inner, pt.eye(R.shape[0]))
        K = matrix_dot(K_inner_inv, B.T, X, A)

        A_tilde = A - B.dot(K)

        dX_symm = 0.5 * (dX + dX.T)
        S = solve_discrete_lyapunov(A_tilde, dX_symm).astype(dX.type.dtype)

        A_bar = 2 * matrix_dot(X, A_tilde, S)
        B_bar = -2 * matrix_dot(X, A_tilde, S, K.T)
        Q_bar = S
        R_bar = matrix_dot(K, S, K.T)

        return [A_bar, B_bar, Q_bar, R_bar]


def solve_discrete_are(A, B, Q, R, enforce_Q_symmetric=False) -> TensorVariable:
    """
    Solve the discrete Algebraic Riccati equation :math:`A^TXA - X - (A^TXB)(R + B^TXB)^{-1}(B^TXA) + Q = 0`.

    Parameters
    ----------
    A: ArrayLike
        Square matrix of shape M x M
    B: ArrayLike
        Matrix of shape M x N
    Q: ArrayLike
        Symmetric square matrix of shape M x M
    R: ArrayLike
        Square matrix of shape N x N
    enforce_Q_symmetric: bool
        If True, the provided Q matrix is transformed to 0.5 * (Q + Q.T) to ensure symmetry

    Returns
    -------
    X: pt.matrix
        Square matrix of shape M x M, representing the solution to the DARE
    """

    return typing.cast(
        TensorVariable, SolveDiscreteARE(enforce_Q_symmetric)(A, B, Q, R)
    )


@@ -832,4 +907,8 @@ def solve_continuous_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable:
__all__ = [
    "eigvalsh",
    "kron",
    "expm",
    "solve_discrete_lyapunov",
    "solve_continuous_lyapunov",
    "solve_discrete_are",
    "solve_triangular",
]
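The adjoint formulas in `SolveDiscreteARE.grad` follow from the fixed-point identity `X = A_tilde.T @ X @ A + Q` with `A_tilde = A - B @ K`: treating the gain `K` as stationary, a perturbation of `A` propagates through a discrete Lyapunov equation whose solution is `S`, giving `A_bar = 2 @ X @ A_tilde @ S`. Below is a pure NumPy/SciPy sketch that checks the `A_bar` formula against a central finite difference; the matrix values and the check itself are illustrative assumptions, not part of the diff:

```python
# Finite-difference sanity check of the A_bar formula from
# Kao & Hennequin (2020) used in SolveDiscreteARE.grad above.
# Pure NumPy/SciPy; values are illustrative (darex #1).
import numpy as np
from scipy import linalg

A = np.array([[4.0, 3.0], [-4.5, -3.5]])
B = np.array([[1.0], [-1.0]])
Q = np.array([[9.0, 6.0], [6.0, 4.0]])
R = np.array([[1.0]])

def dare(A_mat):
    return linalg.solve_discrete_are(A_mat, B, Q, R)

X = dare(A)
K = linalg.solve(R + B.T @ X @ B, B.T @ X @ A)  # stabilizing feedback gain
A_tilde = A - B @ K

rng = np.random.default_rng(0)
dX = rng.normal(size=X.shape)
dX_symm = 0.5 * (dX + dX.T)                     # symmetrized output gradient
# S solves A_tilde @ S @ A_tilde.T - S + dX_symm = 0
S = linalg.solve_discrete_lyapunov(A_tilde, dX_symm)
A_bar = 2.0 * X @ A_tilde @ S                   # adjoint w.r.t. A

# <A_bar, dA> should match the directional derivative of <dX_symm, X(A)>.
dA = rng.normal(size=A.shape)
eps = 1e-6
fd = np.sum(dX_symm * (dare(A + eps * dA) - dare(A - eps * dA))) / (2 * eps)
print(np.isclose(np.sum(A_bar * dA), fd, rtol=1e-5))  # expected: True
```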
57 changes: 53 additions & 4 deletions tests/tensor/test_slinalg.py
@@ -22,6 +22,7 @@
    kron,
    solve,
    solve_continuous_lyapunov,
    solve_discrete_are,
    solve_discrete_lyapunov,
    solve_triangular,
)
@@ -532,7 +533,7 @@ def test_perform(self):
                    scipy_val = scipy.linalg.kron(a[np.newaxis, :], b).flatten()
                else:
                    scipy_val = scipy.linalg.kron(a, b)
                utt.assert_allclose(out, scipy_val)
                np.testing.assert_allclose(out, scipy_val)

    def test_numpy_2d(self):
        for shp0 in [(2, 3)]:
@@ -564,7 +565,10 @@ def test_solve_discrete_lyapunov_via_direct_real():
    utt.verify_grad(solve_discrete_lyapunov, pt=[A, Q], rng=rng)


@pytest.mark.filterwarnings("ignore::UserWarning")
def test_solve_discrete_lyapunov_via_direct_complex():
    # Conj does not have a C implementation; filter the resulting UserWarning.

    N = 5
    rng = np.random.default_rng(utt.fetch_seed())
    a = pt.zmatrix()
@@ -574,7 +578,7 @@ def test_solve_discrete_lyapunov_via_direct_complex():
    A = rng.normal(size=(N, N)) + rng.normal(size=(N, N)) * 1j
    Q = rng.normal(size=(N, N))
    X = f(A, Q)
    assert np.allclose(A @ X @ A.conj().T - X + Q, 0.0)
    np.testing.assert_array_less(np.abs(A @ X @ A.conj().T - X + Q), 1e-12)

    # TODO: the .conj() method currently does not have a gradient; add this test when gradients are implemented.
    # utt.verify_grad(solve_discrete_lyapunov, pt=[A, Q], rng=rng)
@@ -591,8 +595,8 @@ def test_solve_discrete_lyapunov_via_bilinear():
    Q = rng.normal(size=(N, N))

    X = f(A, Q)
    assert np.allclose(A @ X @ A.conj().T - X + Q, 0.0)

    np.testing.assert_array_less(np.abs(A @ X @ A.conj().T - X + Q), 1e-12)
    utt.verify_grad(solve_discrete_lyapunov, pt=[A, Q], rng=rng)


@@ -607,6 +611,51 @@ def test_solve_continuous_lyapunov():
    Q = rng.normal(size=(N, N))
    X = f(A, Q)

    assert np.allclose(A @ X + X @ A.conj().T, Q)
    Q_recovered = A @ X + X @ A.conj().T

    np.testing.assert_allclose(Q_recovered.squeeze(), Q)
    utt.verify_grad(solve_continuous_lyapunov, pt=[A, Q], rng=rng)


def test_solve_discrete_are_forward():
    # TEST CASE 4 : darex #1 -- taken from Scipy tests
    a, b, q, r = (
        np.array([[4, 3], [-4.5, -3.5]]),
        np.array([[1], [-1]]),
        np.array([[9, 6], [6, 4]]),
        np.array([[1]]),
    )
    a, b, q, r = (x.astype(config.floatX) for x in [a, b, q, r])

    x = solve_discrete_are(a, b, q, r).eval()
    res = a.T.dot(x.dot(a)) - x + q
    res -= (
        a.conj()
        .T.dot(x.dot(b))
        .dot(np.linalg.solve(r + b.conj().T.dot(x.dot(b)), b.T).dot(x.dot(a)))
    )

    atol = 1e-4 if config.floatX == "float32" else 1e-12
    np.testing.assert_allclose(res, np.zeros_like(res), atol=atol)


def test_solve_discrete_are_grad():
    a, b, q, r = (
        np.array([[4, 3], [-4.5, -3.5]]),
        np.array([[1], [-1]]),
        np.array([[9, 6], [6, 4]]),
        np.array([[1]]),
    )
    a, b, q, r = (x.astype(config.floatX) for x in [a, b, q, r])

    rng = np.random.default_rng(utt.fetch_seed())

    # TODO: Is there a "theoretically motivated" value to use here? I pulled 1e-4 out of a hat
    atol = 1e-4 if config.floatX == "float32" else 1e-12

    utt.verify_grad(
        functools.partial(solve_discrete_are, enforce_Q_symmetric=True),
        pt=[a, b, q, r],
        rng=rng,
        abs_tol=atol,
    )
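The "direct" discrete Lyapunov path exercised by the tests above solves the equation by vectorization: with row-major `vec`, `A X A^T - X + Q = 0` becomes `(A kron A) vec(X) - vec(X) + vec(Q) = 0`, which is what `_direct_solve_discrete_lyapunov` builds with `kron` and `solve`. A plain NumPy sketch of that identity, assuming real inputs and illustrative random values:

```python
# Sketch of the vectorization identity behind the "direct" method in
# _direct_solve_discrete_lyapunov, for real inputs only.
import numpy as np
from scipy import linalg

rng = np.random.default_rng(0)
N = 5
A = 0.1 * rng.normal(size=(N, N))  # scaled so the spectral radius is < 1
Q = rng.normal(size=(N, N))

# Row-major vec: vec(A @ X @ A.T) == np.kron(A, A) @ vec(X)
AA = np.kron(A, A)
X = linalg.solve(np.eye(N * N) - AA, Q.ravel()).reshape(N, N)

print(np.allclose(A @ X @ A.T - X + Q, 0.0))                 # expected: True
print(np.allclose(X, linalg.solve_discrete_lyapunov(A, Q)))  # cross-check
```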