diff --git a/examples/kalman_tracking_2d_example.jl b/examples/kalman_tracking_2d_example.jl
new file mode 100644
index 00000000..9beada9c
--- /dev/null
+++ b/examples/kalman_tracking_2d_example.jl
@@ -0,0 +1,68 @@
+# Track a 2d point moving with (roughly) constant velocity from noisy
+# position observations, then compare Kalman filtering and smoothing.
+using LinearAlgebra, Distributions, Random
+using QuantEcon # assumed entry point for Kalman, set_state!, sample, filter, smooth
+
+function make_params()
+    # State = (x, y, dx, dy); observation = (x, y).
+    F = [1 0 1 0; 0 1 0 1; 0 0 1 0; 0 0 0 1.0] # constant-velocity dynamics
+    H = [1.0 0 0 0; 0 1 0 0]                   # observe position only
+    nobs, nhidden = size(H)
+    Q = Matrix(I, nhidden, nhidden) .* 0.001   # process noise covariance
+    R = Matrix(I, nobs, nobs) .* 0.1           # observation noise covariance
+    mu0 = [8, 10, 1, 0.0]                      # prior mean
+    V0 = Matrix(I, nhidden, nhidden) .* 1.0    # prior covariance
+    return (mu0 = mu0, V0 = V0, F = F, H = H, Q = Q, R = R)
+end
+
+# Plot a level set of the Gaussian N(m, C) as an ellipse.
+# https://github.com/probml/pmtk3/blob/master/matlabTools/graphics/gaussPlot2d.m
+function plot_gauss2d(m, C)
+    U = eigvecs(C)
+    D = eigvals(C)
+    N = 100
+    t = range(0, stop=2*pi, length=N)
+    xy = zeros(Float64, 2, N)
+    xy[1,:] = cos.(t)
+    xy[2,:] = sin.(t)
+    #k = sqrt(6) # approx sqrt(chi2inv(0.95, 2)) = 2.45, for a 95% ellipse
+    k = 1.0
+    w = (k * U * Diagonal(sqrt.(D))) * xy # 2 x N
+    handle = Plots.plot!(w[1,:] .+ m[1], w[2,:] .+ m[2], label="")
+    return handle
+end
+
+function do_plot(zs, ys, m, V)
+    # m is nhidden x T and V is nhidden x nhidden x T (nhidden = 4 states);
+    # only the position components (rows 1:2) are plotted.
+    plt = scatter(ys[1,:], ys[2,:], label="observed", reuse=false)
+    plt = scatter!(zs[1,:], zs[2,:], label="true", marker=:star)
+    xlims!(minimum(ys[1,:])-1, maximum(ys[1,:])+1)
+    ylims!(minimum(ys[2,:])-1, maximum(ys[2,:])+1)
+    m2 = m[1:2,:]
+    V2 = V[1:2, 1:2, :]
+    T = size(m2, 2)
+    for t=1:T
+        plt = plot_gauss2d(m2[:,t], V2[:,:,t])
+    end
+    display(plt)
+end
+
+Random.seed!(2)
+T = 10
+params = make_params()
+F = params.F; H = params.H; Q = params.Q; R = params.R; mu0 = params.mu0; V0 = params.V0
+kf = Kalman(F, H, Q, R)
+set_state!(kf, mu0, V0)
+zs, ys = QuantEcon.sample(kf, T) # nhidden x T, nobs x T
+println("inference")
+set_state!(kf, mu0, V0)
+mF, loglik, VF = QuantEcon.filter(kf, ys) # qualified to avoid Base.filter
+set_state!(kf, mu0, V0)
+mS, loglik, VS = QuantEcon.smooth(kf, ys)
+
+println("plotting")
+using Plots; pyplot()
+closeall()
+do_plot(zs, ys, mF, VF); title!("Filtering")
+do_plot(zs, ys, mS, VS); title!("Smoothing")
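Side note on the `k = sqrt(6)` comment in `plot_gauss2d`: the exact scaling for a 95% coverage ellipse of a 2d Gaussian is the square root of the 0.95 quantile of a chi-squared distribution with 2 degrees of freedom. A minimal check, using only Distributions (already loaded by the example):

```julia
# sqrt(chi2inv(0.95, 2)) in Matlab notation; for Chisq(2) this is sqrt(-2*log(0.05))
using Distributions
k95 = sqrt(quantile(Chisq(2), 0.95))  # ≈ 2.4477
println((k95, sqrt(6)))               # sqrt(6) ≈ 2.4495 is the approximation in the comment
```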
diff --git a/src/kalman.jl b/src/kalman.jl
index 7b5bbdfc..fb0af020 100644
--- a/src/kalman.jl
+++ b/src/kalman.jl
@@ -14,6 +14,8 @@ https://lectures.quantecon.org/jl/kalman.html
 TODO: Do docstrings here after implementing LinerStateSpace
 =#
 
+import Distributions
+
 mutable struct Kalman
     A
     G
@@ -42,25 +44,27 @@ function set_state!(k::Kalman, x_hat, Sigma)
     Nothing
 end
 
-@doc doc"""
+#= TODO: restore as a docstring once the doc build can render it again.
+"""
 Updates the moments (`cur_x_hat`, `cur_sigma`) of the time ``t`` prior to the
 time ``t`` filtering distribution, using current measurement ``y_t``.
 
 The updates are according to
 
 ```math
     \hat{x}^F = \hat{x} + \Sigma G' (G \Sigma G' + R)^{-1}
                     (y - G \hat{x}) \\
     \Sigma^F = \Sigma - \Sigma G' (G \Sigma G' + R)^{-1} G
                \Sigma
 ```
 
 #### Arguments
 
 - `k::Kalman` An instance of the Kalman filter
 - `y` The current measurement
 
 """
+=#
 function prior_to_filtered!(k::Kalman, y)
     # simplify notation
     G, R = k.G, k.R
@@ -173,26 +177,27 @@
 ##### Arguments
 - `kn::Kalman`: `Kalman` specifying the model. Initial value must be the prior
   for t=1 period observation, i.e. ``x_{1|0}``.
-- `y::AbstractMatrix`: `n x T` matrix of observed data.
-  `n` is the number of observed variables in one period.
+- `y::AbstractMatrix`: `k x T` matrix of observed data.
+  `k` is the number of observed variables in one period.
   Each column is a vector of observations at each period.
 
 ##### Returns
-- `x_smoothed::AbstractMatrix`: `k x T` matrix of smoothed mean of states.
-  `k` is the number of states.
+- `x_filtered::AbstractMatrix`: `n x T` matrix of filtered means of the states.
+  `n` is the number of states.
 - `logL::Real`: log-likelihood of all observations
-- `sigma_smoothed::AbstractArray` `k x k x T` array of smoothed covariance matrix of states.
+- `sigma_filtered::AbstractArray`: `n x n x T` array of filtered covariance matrices of the states.
+- `sigma_forecast::AbstractArray`: `n x n x T` array of one-step-ahead predictive covariance matrices of the states.
 """
-function smooth(kn::Kalman, y::AbstractMatrix)
-    G, R = kn.G, kn.R
-
+# NB: this defines a module-local `filter`, which shadows `Base.filter`;
+# external callers may need to qualify it, e.g. `QuantEcon.filter`.
+function filter(kn::Kalman, y::AbstractMatrix)
     T = size(y, 2)
-    n = kn.n
+    k, n = size(kn.G)
+    @assert n == kn.n
     x_filtered = Matrix{Float64}(undef, n, T)
     sigma_filtered = Array{Float64}(undef, n, n, T)
     sigma_forecast = Array{Float64}(undef, n, n, T)
     logL = 0
-    # forecast and update
     for t in 1:T
         logL = logL + log_likelihood(kn, y[:, t])
         prior_to_filtered!(kn, y[:, t])
@@ -200,7 +205,26 @@
         filtered_to_forecast!(kn)
         sigma_forecast[:, :, t] = kn.cur_sigma
     end
-    # smoothing
+    return x_filtered, logL, sigma_filtered, sigma_forecast
+end
+
+"""
+##### Arguments
+- `kn::Kalman`: `Kalman` specifying the model. Initial value must be the prior
+  for the t=1 period observation, i.e. ``x_{1|0}``.
+- `y::AbstractMatrix`: `k x T` matrix of observed data.
+  `k` is the number of observed variables in one period.
+  Each column is a vector of observations at each period.
+
+##### Returns
+- `x_smoothed::AbstractMatrix`: `n x T` matrix of smoothed means of the states.
+  `n` is the number of states.
+- `logL::Real`: log-likelihood of all observations
+- `sigma_smoothed::AbstractArray`: `n x n x T` array of smoothed covariance matrices of the states.
+"""
+function smooth(kn::Kalman, y::AbstractMatrix)
+    T = size(y, 2)
+    x_filtered, logL, sigma_filtered, sigma_forecast = filter(kn, y)
     x_smoothed = copy(x_filtered)
     sigma_smoothed = copy(sigma_filtered)
     for t in (T-1):-1:1
@@ -209,7 +233,6 @@ function smooth(kn::Kalman, y::AbstractMatrix)
             go_backward(kn, x_filtered[:, t], sigma_filtered[:, :, t],
                         sigma_forecast[:, :, t], x_smoothed[:, t+1],
                         sigma_smoothed[:, :, t+1])
     end
-
     return x_smoothed, logL, sigma_smoothed
 end
@@ -236,3 +259,31 @@ function go_backward(k::Kalman, x_fi::Vector,
     sigma_s = sigma_fi + temp*(sigma_s1-sigma_fo)*temp'
     return x_s, sigma_s
 end
+
+"""
+##### Arguments
+- `kn::Kalman`: `Kalman` specifying the model. Its current state (see
+  `set_state!`) is used as the prior for the first hidden state.
+- `T`: number of time steps to sample for
+
+##### Returns
+- `xs::Matrix`: `xs[:,t]` is the sampled hidden state for period `t`
+- `ys::Matrix`: `ys[:,t]` is the sampled observation for period `t`
+"""
+function sample(kn::Kalman, T::Int)
+    nobs, nhidden = size(kn.G)
+    xs = Array{Float64}(undef, nhidden, T)
+    ys = Array{Float64}(undef, nobs, T)
+    prior_z = Distributions.MvNormal(kn.cur_x_hat, kn.cur_sigma)
+    # zero-mean noise distributions; the mean is passed explicitly because
+    # the covariance-only MvNormal constructor is deprecated in Distributions
+    process_noise_dist = Distributions.MvNormal(zeros(nhidden), kn.Q)
+    obs_noise_dist = Distributions.MvNormal(zeros(nobs), kn.R)
+    xs[:,1] = rand(prior_z)
+    ys[:,1] = kn.G*xs[:,1] + rand(obs_noise_dist)
+    for t=2:T
+        xs[:,t] = kn.A*xs[:,t-1] + rand(process_noise_dist)
+        ys[:,t] = kn.G*xs[:,t] + rand(obs_noise_dist)
+    end
+    return xs, ys
+end
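For reference, the backward pass that `smooth` delegates to `go_backward` is the standard Rauch-Tung-Striebel recursion. In the notation of the docstrings above (with ``A`` the state transition matrix), `temp` plays the role of the smoother gain ``J_t``, `sigma_fo` is the predictive covariance ``\Sigma_{t+1|t}`` stored in `sigma_forecast`, and `sigma_s1` is ``\Sigma_{t+1|T}``:

```math
J_t = \Sigma_{t|t} A' \Sigma_{t+1|t}^{-1} \\
\hat{x}_{t|T} = \hat{x}_{t|t} + J_t (\hat{x}_{t+1|T} - A \hat{x}_{t|t}) \\
\Sigma_{t|T} = \Sigma_{t|t} + J_t (\Sigma_{t+1|T} - \Sigma_{t+1|t}) J_t'
```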
diff --git a/test/test_kalman.jl b/test/test_kalman.jl
index 0d8a7670..f9a53dae 100644
--- a/test/test_kalman.jl
+++ b/test/test_kalman.jl
@@ -85,12 +85,42 @@
     k = Kalman(A, G, Q, R)
     cov_init = [0.722222222222222 0.386904761904762;
                 0.386904761904762 0.293154761904762]
+    set_state!(k, zeros(2), cov_init)
+    x_filtered, logL_filtered, P_filtered, P_predictive = QuantEcon.filter(k, y)
     set_state!(k, zeros(2), cov_init)
-    x_smoothed, logL, P_smoothed = smooth(k, y)
-    x_matlab = [1.36158275104493 2.68312458668362 4.04291315305382 5.36947053521018;
-                0.813542618042249 1.64113106904578 2.43805629027213 3.22585113133984]
+    x_smoothed, logL_smoothed, P_smoothed = smooth(k, y)
+
+    # Verify that the Julia results match the Matlab reference implementation:
+    # https://github.com/probml/pmtk3/blob/master/demos/kalman_test.m
+    xfilt_matlab =
+        [1.3409 2.6585 4.0142 5.3695;
+         0.8076 1.6334 2.4293 3.2259]
+    xsmooth_matlab =
+        [1.3616 2.6831 4.0429 5.3695;
+         0.8135 1.6411 2.4381 3.2259]
     logL_matlab = -22.1434290195012
-    @test isapprox(x_smoothed, x_matlab)
-    @test isapprox(logL, logL_matlab)
+    @test isapprox(x_filtered, xfilt_matlab; rough_kwargs...)
+    @test isapprox(x_smoothed, xsmooth_matlab; rough_kwargs...)
+    @test isapprox(logL_filtered, logL_matlab; rough_kwargs...)
+    @test isapprox(logL_smoothed, logL_matlab; rough_kwargs...)
+
+    set_state!(k, zeros(2), cov_init)
+    xs, ys = QuantEcon.sample(k, 10)
+    @test size(xs) == (2, 10)
+    @test size(ys) == (1, 10)
+
+    #= Monte Carlo check that the first sampled state matches the prior moments;
+       left disabled because it is slow and stochastic.
+    N = 5000
+    x1 = zeros(2, N)
+    for n in 1:N
+        xs, ys = QuantEcon.sample(k, 2)
+        x1[:, n] = xs[:, 1]
+    end
+    m = vec(StatsBase.mean(x1, dims=2))
+    @test isapprox(m, zeros(2); atol=1e-1)
+    C = StatsBase.cov(x1')
+    @test isapprox(C, cov_init; atol=1e-1)
+    =#
 
 end # @testset
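A minimal end-to-end sketch of the new API, using the same model as the test above. It assumes this patch is applied to QuantEcon.jl and that `Kalman`, `set_state!`, and `smooth` are exported; `sample` and `filter` are qualified with the module name since `StatsBase.sample` and `Base.filter` would otherwise shadow them:

```julia
using LinearAlgebra, Random
using QuantEcon  # assumed to provide Kalman, set_state!, smooth after this patch
Random.seed!(42)

A = [0.5 0.4; 0.3 0.2]
G = [1.0 2.0]
Q = [0.34 0.17; 0.17 0.26]
R = fill(0.26, 1, 1)
k = Kalman(A, G, Q, R)

prior_mean, prior_cov = zeros(2), Matrix(1.0I, 2, 2)
set_state!(k, prior_mean, prior_cov)
xs, ys = QuantEcon.sample(k, 20)                 # synthetic states and observations

set_state!(k, prior_mean, prior_cov)
xf, logL_f, Pf, Ppred = QuantEcon.filter(k, ys)  # forward pass only

set_state!(k, prior_mean, prior_cov)
xsm, logL_s, Psm = QuantEcon.smooth(k, ys)       # forward pass + RTS backward pass

@assert logL_f ≈ logL_s  # smooth reuses filter's forward pass
```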