Skip to content

Commit

Permalink
Riccati: Add option to use scipy.linalg.solve_discrete_are
Browse files Browse the repository at this point in the history
Close #360
  • Loading branch information
oyamad committed Oct 24, 2017
1 parent 6b1f9a8 commit 6858018
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 51 deletions.
46 changes: 30 additions & 16 deletions quantecon/lqcontrol.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class LQ:
x_{t+1} = A x_t + B u_t + C w_{t+1}
Here :math:`x` is n x 1, :math:`u` is k x 1, :math:`w` is j x 1 and the
Here :math:`x` is n x 1, :math:`u` is k x 1, :math:`w` is j x 1 and the
matrices are conformable for these dimensions. The sequence :math:`{w_t}`
is assumed to be white noise, with zero mean and
:math:`\mathbb{E} [ w_t' w_t ] = I`, the j x j identity.
Expand All @@ -68,16 +68,13 @@ class LQ:
Parameters
----------
Q : array_like(float)
Q is the payoff(or cost) matrix that corresponds with the
Q is the payoff (or cost) matrix that corresponds with the
control variable u and is k x k. Should be symmetric and
nonnegative definite
non-negative definite
R : array_like(float)
R is the payoff(or cost) matrix that corresponds with the
R is the payoff (or cost) matrix that corresponds with the
state variable x and is n x n. Should be symetric and
non-negative definite
N : array_like(float)
N is the cross product term in the payoff, as above. It should
be k x n.
A : array_like(float)
A is part of the state transition as described above. It should
be n x n
Expand All @@ -88,6 +85,9 @@ class LQ:
C is part of the state transition as described above and
corresponds to the random variable today. If the model is
deterministic then C should take default value of None
N : array_like(float), optional(default=None)
N is the cross product term in the payoff, as above. It should
be k x n.
beta : scalar(float), optional(default=1)
beta is the discount parameter
T : scalar(int), optional(default=None)
Expand All @@ -97,7 +97,6 @@ class LQ:
matrix that corresponds with the control variable u and is n x
n. Should be symetric and non-negative definite
Attributes
----------
Q, R, N, A, B, C, beta, T, Rf : see Parameters
Expand Down Expand Up @@ -197,17 +196,26 @@ def update_values(self):
# == Set new state == #
self.P, self.d = new_P, new_d

def stationary_values(self):
def stationary_values(self, method='doubling'):
"""
Computes the matrix :math:`P` and scalar :math:`d` that represent
Computes the matrix :math:`P` and scalar :math:`d` that represent
the value function
.. math::
V(x) = x' P x + d
in the infinite horizon case. Also computes the control matrix
:math:`F` from :math:`u = - Fx`
:math:`F` from :math:`u = - Fx`. Computation is via the solution
algorithm as specified by the `method` option (default to the
doubling algorithm) (see the documentation in
`matrix_eqn.solve_discrete_riccati`).
Parameters
----------
method : str, optional(default='doubling')
Solution method used in solving the associated Riccati
equation, str in {'doubling', 'qz'}.
Returns
-------
Expand All @@ -227,7 +235,7 @@ def stationary_values(self):

# === solve Riccati equation, obtain P === #
A0, B0 = np.sqrt(self.beta) * A, np.sqrt(self.beta) * B
P = solve_discrete_riccati(A0, B0, R, Q, N)
P = solve_discrete_riccati(A0, B0, R, Q, N, method=method)

# == Compute F == #
S1 = Q + self.beta * dot(B.T, dot(P, B))
Expand All @@ -242,28 +250,34 @@ def stationary_values(self):

return P, F, d

def compute_sequence(self, x0, ts_length=None, random_state=None):
def compute_sequence(self, x0, ts_length=None, method='doubling',
random_state=None):
"""
Compute and return the optimal state and control sequences
:math:`x_0, ..., x_T` and :math:`u_0,..., u_T` under the
assumption that :math:`{w_t}` is iid and :math:`N(0, 1)`.
Parameters
===========
----------
x0 : array_like(float)
The initial state, a vector of length n
ts_length : scalar(int)
Length of the simulation -- defaults to T in finite case
method : str, optional(default='doubling')
Solution method used in solving the associated Riccati
equation, str in {'doubling', 'qz'}. Only relevant when the
`T` attribute is `None` (i.e., the horizon is infinite).
random_state : int or np.random.RandomState, optional
Random seed (integer) or np.random.RandomState instance to set
the initial state of the random number generator for
reproducibility. If None, a randomly initialized RandomState is
used.
Returns
========
-------
x_path : array_like(float)
An n x T+1 matrix, where the t-th column represents :math:`x_t`
Expand All @@ -286,7 +300,7 @@ def compute_sequence(self, x0, ts_length=None, random_state=None):
# == Preliminaries, infinite horizon case == #
else:
T = ts_length if ts_length else 100
self.stationary_values()
self.stationary_values(method=method)

# == Set up initial condition and arrays to store paths == #
random_state = check_random_state(random_state)
Expand Down
40 changes: 27 additions & 13 deletions quantecon/matrix_eqn.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""
Filename: matrix_eqn.py
This files holds several functions that are used to solve matrix
This file holds several functions that are used to solve matrix
equations. Currently has functionality to solve:
* Lyapunov Equations
* Ricatti Equations
* Riccati Equations
TODO: 1. See issue 47 on github repository, should add support for
Sylvester equations
Expand All @@ -16,6 +16,7 @@
from numpy import dot
from numpy.linalg import solve
from scipy.linalg import solve_discrete_lyapunov as sp_solve_discrete_lyapunov
from scipy.linalg import solve_discrete_are as sp_solve_discrete_are


EPS = np.finfo(float).eps
Expand Down Expand Up @@ -60,9 +61,9 @@ def solve_discrete_lyapunov(A, B, max_it=50, method="doubling"):
approach.
Returns
========
-------
gamma1: array_like(float, ndim=2)
Represents the value :math:`V`
Represents the value :math:`X`
"""
if method == "doubling":
Expand Down Expand Up @@ -98,21 +99,19 @@ def solve_discrete_lyapunov(A, B, max_it=50, method="doubling"):
return gamma1


def solve_discrete_riccati(A, B, Q, R, N=None, tolerance=1e-10, max_iter=500):
def solve_discrete_riccati(A, B, Q, R, N=None, tolerance=1e-10, max_iter=500,
method="doubling"):
"""
Solves the discrete-time algebraic Riccati equation
.. math::
X = A'XA - (N + B'XA)'(B'XB + R)^{-1}(N + B'XA) + Q
via a modified structured doubling algorithm. An explanation of the
algorithm can be found in the reference below.
Note that SciPy also has a discrete riccati equation solver. However it
cannot handle the case where :math:`R` is not invertible, or when :math:`N`
is nonzero. Both of these cases can be handled in the algorithm implemented
below.
Computation is via a modified structured doubling algorithm, an
explanation of which can be found in the reference below, if
`method="doubling"` (default), and via a QZ decomposition method by
calling `scipy.linalg.solve_discrete_are` if `method="qz"`.
Parameters
----------
Expand All @@ -130,11 +129,16 @@ def solve_discrete_riccati(A, B, Q, R, N=None, tolerance=1e-10, max_iter=500):
The tolerance level for convergence
max_iter : scalar(int), optional(default=500)
The maximum number of iterations allowed
method : string, optional(default="doubling")
Describes the solution method to use. If it is "doubling" then
uses the doubling algorithm to solve, if it is "qz" then it uses
`scipy.linalg.solve_discrete_are` (in which case `tolerance` and
`max_iter` are irrelevant).
Returns
-------
X : array_like(float, ndim=2)
The fixed point of the Riccati equation; a k x k array
The fixed point of the Riccati equation; a k x k array
representing the approximate solution
References
Expand All @@ -145,6 +149,11 @@ def solve_discrete_riccati(A, B, Q, R, N=None, tolerance=1e-10, max_iter=500):
(2010): pp-935.
"""
methods = ['doubling', 'qz']
if method not in methods:
msg = "Check your method input. Should be {} or {}".format(*methods)
raise ValueError(msg)

# == Set up == #
error = tolerance + 1
fail_msg = "Convergence failed after {} iterations."
Expand All @@ -158,6 +167,11 @@ def solve_discrete_riccati(A, B, Q, R, N=None, tolerance=1e-10, max_iter=500):
else:
N = np.atleast_2d(N)

if method == 'qz':
X = sp_solve_discrete_are(A, B, Q, R, s=N.T)
return X

# if method == 'doubling'
# == Choose optimal value of gamma in R_hat = R + gamma B'B == #
current_min = np.inf
candidates = (0.0, 0.01, 0.1, 0.25, 0.5, 1.0, 2.0, 10.0, 100.0, 10e5)
Expand Down
21 changes: 11 additions & 10 deletions quantecon/tests/test_lqcontrol.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def setUp(self):

self.lq_scalar = LQ(q, r, a, b, C=c, beta=beta, T=T, Rf=rf)


Q = np.array([[0., 0.], [0., 1]])
R = np.array([[1., 0.], [0., 0]])
RF = np.eye(2) * 100
Expand All @@ -39,12 +38,12 @@ def setUp(self):

self.lq_mat = LQ(Q, R, A, B, beta=beta, T=T, Rf=RF)

self.methods = ['doubling', 'qz']

def tearDown(self):
del self.lq_scalar
del self.lq_mat


def test_scalar_sequences(self):

lq_scalar = self.lq_scalar
Expand All @@ -56,15 +55,17 @@ def test_scalar_sequences(self):
u_0 = (-2*lq_scalar.A*lq_scalar.B*lq_scalar.beta*lq_scalar.Rf) / \
(2*lq_scalar.Q+lq_scalar.beta*lq_scalar.Rf*2*lq_scalar.B**2) \
* x0
x_1 = lq_scalar.A * x0 + lq_scalar.B * u_0 + dot(lq_scalar.C, w_seq[0, -1])
x_1 = lq_scalar.A * x0 + lq_scalar.B * u_0 + \
dot(lq_scalar.C, w_seq[0, -1])

assert_allclose(u_0, u_seq, rtol=1e-4)
assert_allclose(x_1, x_seq[0, -1], rtol=1e-4)

def test_scalar_sequences_with_seed(self):
lq_scalar = self.lq_scalar
x0 = 2
x_seq, u_seq, w_seq = lq_scalar.compute_sequence(x0, 10, 5)
x_seq, u_seq, w_seq = \
lq_scalar.compute_sequence(x0, 10, random_state=5)

expected_output = np.array([[ 0.44122749, -0.33087015]])

Expand All @@ -80,20 +81,20 @@ def test_mat_sequences(self):
assert_allclose(np.sum(u_seq), .95 * np.sum(x0), atol=1e-3)
assert_allclose(x_seq[:, -1], np.zeros_like(x0), atol=1e-3)


def test_stationary_mat(self):
x0 = np.random.randn(2) * 25
lq_mat = self.lq_mat

P, F, d = lq_mat.stationary_values()
f_answer = np.array([[-.95, -.95], [0., 0.]])
p_answer = np.array([[1., 0], [0., 0.]])

val_func_lq = np.dot(x0, P).dot(x0)
val_func_answer = x0[0]**2

assert_allclose(f_answer, F, atol=1e-3)
assert_allclose(val_func_lq, val_func_answer, atol=1e-3)
for method in self.methods:
P, F, d = lq_mat.stationary_values(method=method)
val_func_lq = np.dot(x0, P).dot(x0)

assert_allclose(f_answer, F, atol=1e-3)
assert_allclose(val_func_lq, val_func_answer, atol=1e-3)


if __name__ == '__main__':
Expand Down
Loading

0 comments on commit 6858018

Please sign in to comment.