Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Made a kalman_filter method similar to kalman_smooth. #235

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions examples/kalman_tracking_2d_example.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
using LinearAlgebra, Distributions

# Build the parameters of a 2-D constant-velocity tracking model.
# Hidden state is (x, y, vx, vy); only the noisy (x, y) position is observed.
# Returns a named tuple (mu0, V0, F, H, Q, R).
function make_params()
    nhidden = 4
    nobs = 2
    # Constant-velocity dynamics: position += velocity each step.
    F = [1 0 1 0;
         0 1 0 1;
         0 0 1 0;
         0 0 0 1.0]
    # Observe position components only.
    H = [1.0 0 0 0;
         0 1 0 0]
    Q = 0.001 .* Matrix(I, nhidden, nhidden)  # process-noise covariance
    R = 0.1 .* Matrix(I, nobs, nobs)          # observation-noise covariance
    mu0 = [8, 10, 1, 0.0]                     # prior mean of the initial state
    V0 = 1.0 .* Matrix(I, nhidden, nhidden)   # prior covariance of the initial state
    return (mu0 = mu0, V0 = V0, F = F, H = H, Q = Q, R = R)
end

# https://github.com/probml/pmtk3/blob/master/matlabTools/graphics/gaussPlot2d.m
# Overlay a covariance ellipse for a 2-D Gaussian with mean `m` and
# covariance `C` onto the current plot; returns the plot handle.
# Port of pmtk3's gaussPlot2d.m.
function plot_gauss2d(m, C)
    vecs = eigvecs(C)
    vals = eigvals(C)
    npts = 100
    angles = range(0, stop = 2 * pi, length = npts)
    # Unit circle sampled at npts angles, one point per column (2 x npts).
    unit_circle = zeros(Float64, 2, npts)
    unit_circle[1, :] = cos.(angles)
    unit_circle[2, :] = sin.(angles)
    # scale = sqrt(chi2inv(0.95, 2)) ≈ 2.45 would give a 95% ellipse;
    # 1.0 draws the 1-sigma contour.
    scale = 1.0
    # Map the unit circle through the covariance's eigendecomposition.
    pts = (scale * vecs * Diagonal(sqrt.(vals))) * unit_circle  # 2 x npts
    handle = Plots.plot!(pts[1, :] .+ m[1], pts[2, :] .+ m[2], label="")
    return handle
end

# Scatter the observations and true latent positions, then overlay a
# covariance ellipse (via plot_gauss2d) for each time step's estimate.
# - zs: true hidden states, one column per time step (rows 1:2 = position)
# - ys: 2 x T matrix of observations
# - m:  estimated state means, one column per time step
# - V:  estimated state covariances stacked along dimension 3
# NOTE(review): relies on Plots being loaded and uses the implicit current
# plot; `scatter`/`scatter!` must be in scope at the call site.
function do_plot(zs, ys, m, V)
# m is H*T, V is H*H*T, where H=4 hidden states
plt = scatter(ys[1,:], ys[2,:], label="observed", reuse=false)
plt = scatter!(zs[1,:], zs[2,:], label="true", marker=:star)
# Pad axis limits by 1 unit around the observed points.
xlims!(minimum(ys[1,:])-1, maximum(ys[1,:])+1)
ylims!(minimum(ys[2,:])-1, maximum(ys[2,:])+1)
display(plt)
# Keep only the position block (first two state components) for plotting.
m2 = m[1:2,:]
V2 = V[1:2, 1:2, :]
T = size(m2, 2)
for t=1:T
# plot_gauss2d mutates the current plot and returns its handle.
plt = plot_gauss2d(m2[:,t], V2[:,:,t])
end
display(plt)
end


using Random  # Random.seed! lives in the Random stdlib (required since Julia 0.7)

# Demo: sample a 2-D tracking trajectory, then run filtering and smoothing
# and plot the state estimates with covariance ellipses.
Random.seed!(2)  # fixed seed for a reproducible trajectory
T = 10
params = make_params()
F = params.F; H = params.H; Q = params.Q; R = params.R; mu0 = params.mu0; V0 = params.V0;
kf = Kalman(F, H, Q, R)
set_state!(kf, mu0, V0)
# NOTE(review): src/kalman.jl in this PR defines `sample`, `filter`, and
# `smooth` (with `filter` returning 4 values) — confirm that the `kalman_*`
# names and 3-value destructuring below match the exported API.
zs, ys = kalman_sample(kf, T) # H*T, O*T
println("inference")
set_state!(kf, mu0, V0)  # reset the prior before filtering
mF, loglik, VF = kalman_filter(kf, ys)
set_state!(kf, mu0, V0)  # reset the prior again before smoothing
mS, loglik, VS = kalman_smoother(kf, ys)

println("plotting")
using Plots; pyplot()
closeall()
do_plot(zs, ys, mF, VF); title!("Filtering")
do_plot(zs, ys, mS, VS); title!("Smoothing")
79 changes: 63 additions & 16 deletions src/kalman.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ https://lectures.quantecon.org/jl/kalman.html
TODO: Do docstrings here after implementing LinerStateSpace
=#

import Distributions

mutable struct Kalman
A
G
Expand Down Expand Up @@ -42,25 +44,26 @@ function set_state!(k::Kalman, x_hat, Sigma)
Nothing
end

@doc doc"""
#=
"""
Updates the moments (`cur_x_hat`, `cur_sigma`) of the time ``t`` prior to the
time ``t`` filtering distribution, using current measurement ``y_t``.
The updates are according to

```math
`math
\hat{x}^F = \hat{x} + \Sigma G' (G \Sigma G' + R)^{-1}
(y - G \hat{x}) \\

\Sigma^F = \Sigma - \Sigma G' (G \Sigma G' + R)^{-1} G
\Sigma
```
`

#### Arguments

- `k::Kalman` An instance of the Kalman filter
- `y` The current measurement

"""
=#
function prior_to_filtered!(k::Kalman, y)
# simplify notation
G, R = k.G, k.R
Expand Down Expand Up @@ -173,34 +176,52 @@ end
##### Arguments
- `kn::Kalman`: `Kalman` specifying the model. Initial value must be the prior
for t=1 period observation, i.e. ``x_{1|0}``.
- `y::AbstractMatrix`: `n x T` matrix of observed data.
`n` is the number of observed variables in one period.
- `y::AbstractMatrix`: `k x T` matrix of observed data.
`k` is the number of observed variables in one period.
Each column is a vector of observations at each period.

##### Returns
- `x_smoothed::AbstractMatrix`: `k x T` matrix of smoothed mean of states.
`k` is the number of states.
- `x_filtered::AbstractMatrix`: `n x T` matrix of filtered mean of states.
`n` is the number of states.
- `logL::Real`: log-likelihood of all observations
- `sigma_smoothed::AbstractArray` `k x k x T` array of smoothed covariance matrix of states.
- `sigma_filtered::AbstractArray` `n x n x T` array of filtered covariance matrix of states.
- `sigma_forecast::AbstractArray` `n x n x T` array of predictive covariance matrix of states.
"""
function smooth(kn::Kalman, y::AbstractMatrix)
G, R = kn.G, kn.R

"""
Run the Kalman filter forward over all observations.

##### Arguments
- `kn::Kalman`: `Kalman` specifying the model. Initial value must be the prior
  for the t=1 period observation, i.e. ``x_{1|0}``.
- `y::AbstractMatrix`: `k x T` matrix of observed data. `k` is the number of
  observed variables in one period. Each column is the observation vector for
  one period.

##### Returns
- `x_filtered::AbstractMatrix`: `n x T` matrix of filtered state means.
  `n` is the number of states.
- `logL::Real`: log-likelihood of all observations.
- `sigma_filtered::AbstractArray`: `n x n x T` array of filtered state
  covariance matrices.
- `sigma_forecast::AbstractArray`: `n x n x T` array of one-step-ahead
  predictive state covariance matrices.
"""
function filter(kn::Kalman, y::AbstractMatrix)
    T = size(y, 2)
    k, n = size(kn.G)
    # Explicit validation instead of @assert (which may be compiled away).
    n == kn.n || throw(DimensionMismatch("kn.G has $n columns but kn.n == $(kn.n)"))
    x_filtered = Matrix{Float64}(undef, n, T)
    sigma_filtered = Array{Float64}(undef, n, n, T)
    sigma_forecast = Array{Float64}(undef, n, n, T)
    logL = 0.0  # start as Float64 to avoid Int -> Float64 type instability
    # forecast and update
    for t in 1:T
        logL += log_likelihood(kn, y[:, t])
        prior_to_filtered!(kn, y[:, t])
        x_filtered[:, t], sigma_filtered[:, :, t] = kn.cur_x_hat, kn.cur_sigma
        filtered_to_forecast!(kn)
        sigma_forecast[:, :, t] = kn.cur_sigma
    end
    return x_filtered, logL, sigma_filtered, sigma_forecast
end

"""
##### Arguments
- `kn::Kalman`: `Kalman` specifying the model. Initial value must be the prior
for t=1 period observation, i.e. ``x_{1|0}``.
- `y::AbstractMatrix`: `k x T` matrix of observed data.
`k` is the number of observed variables in one period.
Each column is a vector of observations at each period.

##### Returns
- `x_smoothed::AbstractMatrix`: `n x T` matrix of smoothed mean of states.
`n` is the number of states.
- `logL::Real`: log-likelihood of all observations
- `sigma_smoothed::AbstractArray` `n x n x T` array of smoothed covariance matrix of states.
"""
function smooth(kn::Kalman, y::AbstractMatrix)
T = size(y, 2)
x_filtered, logL, sigma_filtered, sigma_forecast = filter(kn, y)
x_smoothed = copy(x_filtered)
sigma_smoothed = copy(sigma_filtered)
for t in (T-1):-1:1
Expand All @@ -209,7 +230,6 @@ function smooth(kn::Kalman, y::AbstractMatrix)
sigma_forecast[:, :, t], x_smoothed[:, t+1],
sigma_smoothed[:, :, t+1])
end

return x_smoothed, logL, sigma_smoothed
end

Expand All @@ -236,3 +256,30 @@ function go_backward(k::Kalman, x_fi::Vector,
sigma_s = sigma_fi + temp*(sigma_s1-sigma_fo)*temp'
return x_s, sigma_s
end

"""
##### Arguments
- `kn::Kalman`: `Kalman` specifying the model.
- `T`: number of time steps to sample for

##### Returns
- `xs::Matrix`: `xs[:,t]` is sampled hidden state for period t
- `ys::Matrix`: `ys[:,t]` is sampled observation for period `t`
"""
function sample(kn::Kalman, T::Int)
nobs, nhidden = size(kn.G)
xs = Array{Float64}(undef, nhidden, T)
ys = Array{Float64}(undef, nobs, T)
mu0 = kn.cur_x_hat
V0 = kn.cur_sigma
prior_z = Distributions.MvNormal(mu0, V0)
process_noise_dist = Distributions.MvNormal(kn.Q)
obs_noise_dist = Distributions.MvNormal(kn.R)
xs[:,1] = rand(prior_z)
ys[:,1] = kn.G*xs[:,1] + rand(obs_noise_dist)
for t=2:T
xs[:,t] = kn.A*xs[:,t-1] + rand(process_noise_dist)
ys[:,t] = kn.G*xs[:,t] + rand(obs_noise_dist)
end
return xs, ys
end
45 changes: 40 additions & 5 deletions test/test_kalman.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@
# Mdl = ssm(A,B_sigma,C,D);
# [x_matlab, logL_matlab] = smooth(Mdl, y)
# ```



A = [.5 .4;
.3 .2]
Q = [.34 .17;
Expand All @@ -85,12 +88,44 @@
k = Kalman(A, G, Q, R)
cov_init = [0.722222222222222 0.386904761904762;
0.386904761904762 0.293154761904762]
#set_state!(k, zeros(2), cov_init)
#x_filtered, logL_filtered, P_filtered, P_predictive = filter(k, y)
set_state!(k, zeros(2), cov_init)
x_smoothed, logL, P_smoothed = smooth(k, y)
x_matlab = [1.36158275104493 2.68312458668362 4.04291315305382 5.36947053521018;
0.813542618042249 1.64113106904578 2.43805629027213 3.22585113133984]
x_smoothed, logL_smoothed, P_smoothed = smooth(k, y)

# We verify the Julia results match the Matlab code below
#https://github.com/probml/pmtk3/blob/master/demos/kalman_test.m
xfilt_matlab =
[1.3409 2.6585 4.0142 5.3695;
0.8076 1.6334 2.4293 3.2259];
xsmooth_matlab =
[1.3616 2.6831 4.0429 5.3695;
0.8135 1.6411 2.4381 3.2259];
logL_matlab = -22.1434290195012
@test isapprox(x_smoothed, x_matlab)
@test isapprox(logL, logL_matlab)
#@test isapprox(x_filtered, xfilt_matlab; rough_kwargs...)
@test isapprox(x_smoothed, xsmooth_matlab; rough_kwargs...)
#@test isapprox(logL_filtered, logL_matlab; rough_kwargs...)
@test isapprox(logL_smoothed, logL_matlab; rough_kwargs...)

#=
set_state!(k, zeros(2), cov_init)
xs, ys = sample(k, 10)
@test size(xs) == (2,10)
@test size(ys) == (1,10)
=#

#=
N = 5000
x1 = zeros(2,N)
for n in 1:N
#global x1 # only need 'global' when running in REPL
xs, ys = sample(k, 2)
x1[:,n] = xs[:,1]
end
m=StatsBase.mean(x1, dims=2);
@test isapprox(m, zeros(2); atol=1e-1)
C=StatsBase.cov(x1');
@test isapprox(C, cov_init; atol=1e-1)
=#

end # @testset