Skip to content

Commit

Permalink
Extend unit tests for GPEmulator.jl and Observations.jl
Browse files Browse the repository at this point in the history
Update src/Observations.jl

Co-authored-by: Charles Kawczynski <[email protected]>
  • Loading branch information
Melanie Bieli and charleskawczynski committed Jun 27, 2020
1 parent 5fb71c2 commit a163957
Show file tree
Hide file tree
Showing 9 changed files with 217 additions and 96 deletions.
103 changes: 51 additions & 52 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ version = "0.3.3"

[[Adapt]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "fd04049c7dd78cfef0b06cdc1f0f181467655712"
git-tree-sha1 = "0fac443759fa829ed8066db6cf1077d888bb6573"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "1.1.0"
version = "2.0.2"

[[ArnoldiMethod]]
deps = ["DelimitedFiles", "LinearAlgebra", "Random", "SparseArrays", "StaticArrays", "Test"]
Expand All @@ -38,21 +38,21 @@ version = "3.5.0+3"

[[ArrayInterface]]
deps = ["LinearAlgebra", "Requires", "SparseArrays"]
git-tree-sha1 = "649c08a5a3a513f4662673d3777fe6ccb4df9f5d"
git-tree-sha1 = "851de9a8acd7b8863aa2ec2af0a44f375502c878"
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
version = "2.8.7"
version = "2.9.0"

[[ArrayLayouts]]
deps = ["FillArrays", "LinearAlgebra"]
git-tree-sha1 = "89182776a99b69964e995cc2f1e37b5fc3476d56"
git-tree-sha1 = "a3254b3780a3544838ca0b7e23b1e9b06eb71bd8"
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
version = "0.3.4"
version = "0.3.5"

[[BandedMatrices]]
deps = ["ArrayLayouts", "FillArrays", "LinearAlgebra", "Random", "SparseArrays"]
git-tree-sha1 = "fd300e252fa1d96c75884cfa37fd6a5402c79d4b"
git-tree-sha1 = "195ceb173f0759ca595770fac3b379e51579e5e7"
uuid = "aae01518-5342-5314-be14-df237901396f"
version = "0.15.12"
version = "0.15.13"

[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
Expand Down Expand Up @@ -94,15 +94,15 @@ version = "0.8.1"

[[ChainRules]]
deps = ["ChainRulesCore", "LinearAlgebra", "Reexport", "Requires", "Statistics"]
git-tree-sha1 = "85f130f2c5ce208a5a395b550802398d2fcc5ee6"
git-tree-sha1 = "76cd719cb7ab57bd2687dcb3b186c4f99820a79d"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.6.4"
version = "0.6.5"

[[ChainRulesCore]]
deps = ["MuladdMacro"]
git-tree-sha1 = "32e2c6e44d4fdd985b5688b5e85c1f6892cf3d15"
git-tree-sha1 = "c384e0e4fe6bfeb6bec0d41f71cc5e391cd110ba"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.8.0"
version = "0.8.1"

[[Cloudy]]
deps = ["Coverage", "DifferentialEquations", "DocStringExtensions", "ForwardDiff", "HCubature", "LinearAlgebra", "Optim", "PyPlot", "SpecialFunctions", "TaylorSeries", "Test"]
Expand Down Expand Up @@ -187,15 +187,15 @@ version = "1.3.0"

[[DataFrames]]
deps = ["CategoricalArrays", "Compat", "DataAPI", "Future", "InvertedIndices", "IteratorInterfaceExtensions", "Missings", "PooledArrays", "Printf", "REPL", "Reexport", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"]
git-tree-sha1 = "02f08ae77249b7f6d4186b081a016fb7454c616f"
git-tree-sha1 = "e516e72bfb40809b7709cda7bfb39e82ec492d68"
uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
version = "0.21.2"
version = "0.21.3"

[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "be680f1ad03c0a03796aa3fda5a2180df7f83b46"
git-tree-sha1 = "edad9434967fdc0a2631a65d902228400642120c"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.17.18"
version = "0.17.19"

[[DataValueInterfaces]]
git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6"
Expand All @@ -218,9 +218,9 @@ uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"

[[DiffEqBase]]
deps = ["ArrayInterface", "ChainRulesCore", "ConsoleProgressMonitor", "DataStructures", "Distributed", "DocStringExtensions", "FunctionWrappers", "IterativeSolvers", "IteratorInterfaceExtensions", "LabelledArrays", "LinearAlgebra", "Logging", "LoggingExtras", "MuladdMacro", "Parameters", "Printf", "ProgressLogging", "RecipesBase", "RecursiveArrayTools", "RecursiveFactorization", "Requires", "Roots", "SparseArrays", "StaticArrays", "Statistics", "SuiteSparse", "TableTraits", "TerminalLoggers", "TreeViews", "ZygoteRules"]
git-tree-sha1 = "ae65fac7d9933f3d039c0296b5d41bf8c3d8f4ea"
git-tree-sha1 = "eb3cfba5228aceca0024d9a15086d82ef8330d8e"
uuid = "2b5f629d-d688-5b77-993f-72d75c75574e"
version = "6.38.4"
version = "6.39.1"

[[DiffEqCallbacks]]
deps = ["DataStructures", "DiffEqBase", "ForwardDiff", "LinearAlgebra", "NLsolve", "OrdinaryDiffEq", "RecipesBase", "RecursiveArrayTools", "StaticArrays"]
Expand All @@ -230,9 +230,9 @@ version = "2.13.3"

[[DiffEqFinancial]]
deps = ["DiffEqBase", "DiffEqNoiseProcess", "LinearAlgebra", "Markdown", "RandomNumbers"]
git-tree-sha1 = "f0c6f2b0b9fa463a90da06142e45ecf8e0b70bac"
git-tree-sha1 = "db08e0def560f204167c58fd0637298e13f58f73"
uuid = "5a0ffddc-d203-54b0-88ba-2c03c0fc2e67"
version = "2.3.0"
version = "2.4.0"

[[DiffEqJump]]
deps = ["ArrayInterface", "Compat", "DataStructures", "DiffEqBase", "FunctionWrappers", "LinearAlgebra", "Parameters", "PoissonRandom", "Random", "RandomNumbers", "RecursiveArrayTools", "StaticArrays", "Statistics", "TreeViews"]
Expand All @@ -242,9 +242,9 @@ version = "6.9.2"

[[DiffEqNoiseProcess]]
deps = ["DataStructures", "DiffEqBase", "Distributions", "LinearAlgebra", "PoissonRandom", "Random", "RandomNumbers", "RecipesBase", "RecursiveArrayTools", "Requires", "ResettableStacks", "StaticArrays", "Statistics"]
git-tree-sha1 = "fc9ba5c47246d1e6c15ae36ce9f5e67b6ffc06b7"
git-tree-sha1 = "474bba439ce886baab756744c54436d7628ef05e"
uuid = "77a26b50-5914-5dd7-bc55-306e6241c503"
version = "4.2.0"
version = "4.3.0"

[[DiffEqPhysics]]
deps = ["DiffEqBase", "DiffEqCallbacks", "ForwardDiff", "LinearAlgebra", "Printf", "Random", "RecipesBase", "RecursiveArrayTools", "Reexport", "StaticArrays"]
Expand Down Expand Up @@ -317,10 +317,10 @@ uuid = "2904ab23-551e-5aed-883f-487f97af5226"
version = "0.2.1"

[[ExponentialUtilities]]
deps = ["LinearAlgebra", "Printf", "SparseArrays"]
git-tree-sha1 = "1672dedeacaab85345fd359ad56dde8fb5d48a45"
deps = ["LinearAlgebra", "Printf", "Requires", "SparseArrays"]
git-tree-sha1 = "91f7498b66205431fe3e35833cda97a22b1ab6a5"
uuid = "d4d017d3-3776-5f7e-afef-a10c40355c18"
version = "1.6.0"
version = "1.7.0"

[[FastGaussQuadrature]]
deps = ["LinearAlgebra", "SpecialFunctions"]
Expand All @@ -330,9 +330,9 @@ version = "0.4.2"

[[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays"]
git-tree-sha1 = "44f561e293987ffc84272cd3d2b14b0b93123d63"
git-tree-sha1 = "bf726ba7ce99e00d10bf63c031285fb9ab3676ae"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.8.10"
version = "0.8.11"

[[FiniteDiff]]
deps = ["ArrayInterface", "LinearAlgebra", "Requires", "SparseArrays", "StaticArrays"]
Expand Down Expand Up @@ -453,9 +453,9 @@ version = "0.2.0"

[[LLVM]]
deps = ["CEnum", "Libdl", "Printf", "Unicode"]
git-tree-sha1 = "72fc0a39d5899091ff2d4cdaa64cb5e4862cf813"
git-tree-sha1 = "d9c6e1efcaa6c2fcd043da812a62b3e489a109a3"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
version = "1.5.2"
version = "1.7.0"

[[LaTeXStrings]]
git-tree-sha1 = "de44b395389b84fd681394d4e8d39ef14e3a2ea8"
Expand All @@ -481,7 +481,6 @@ uuid = "1d6d02ad-be62-4b6b-8a6d-2f90e265016e"
version = "0.1.2"

[[LibGit2]]
deps = ["Printf"]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"

[[Libdl]]
Expand Down Expand Up @@ -513,9 +512,9 @@ version = "0.4.1"

[[LoopVectorization]]
deps = ["DocStringExtensions", "LinearAlgebra", "OffsetArrays", "SIMDPirates", "SLEEFPirates", "UnPack", "VectorizationBase"]
git-tree-sha1 = "59f7e9fddaae12967a0c0903aff2d06a8813e2b1"
git-tree-sha1 = "f49302d088dadda9dad58e65883ce24413b8c1f4"
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
version = "0.8.5"
version = "0.8.7"

[[METIS_jll]]
deps = ["Libdl", "Pkg"]
Expand Down Expand Up @@ -546,9 +545,9 @@ version = "1.0.2"

[[MbedTLS_jll]]
deps = ["Libdl", "Pkg"]
git-tree-sha1 = "c83f5a1d038f034ad0549f9ee4d5fac3fb429e33"
git-tree-sha1 = "f85473aeb7a2561a5c58c06c4868971ebe2bcbff"
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
version = "2.16.0+2"
version = "2.16.6+0"

[[Missings]]
deps = ["DataAPI"]
Expand Down Expand Up @@ -653,12 +652,12 @@ version = "0.12.1"

[[Parsers]]
deps = ["Dates", "Test"]
git-tree-sha1 = "eb3e09940c0d7ae01b01d9291ebad7b081c844d3"
git-tree-sha1 = "20ef902ea02f7000756a4bc19f7b9c24867c6211"
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
version = "1.0.5"
version = "1.0.6"

[[Pkg]]
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Test", "UUIDs"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

[[PoissonRandom]]
Expand Down Expand Up @@ -739,15 +738,15 @@ version = "1.0.1"

[[RecursiveArrayTools]]
deps = ["ArrayInterface", "LinearAlgebra", "RecipesBase", "Requires", "StaticArrays", "Statistics", "ZygoteRules"]
git-tree-sha1 = "96e71928efa701fa5a6df0f88b51f05ceed70f2c"
git-tree-sha1 = "0ffe36b65f0fc4967a42a673c1a9ffa65724dee6"
uuid = "731186ca-8d62-57ce-b412-fbd966d074cd"
version = "2.4.4"
version = "2.5.0"

[[RecursiveFactorization]]
deps = ["LinearAlgebra", "LoopVectorization"]
git-tree-sha1 = "09217cb106dd826de9960986207175b52e3035f2"
deps = ["LinearAlgebra", "LoopVectorization", "VectorizationBase"]
git-tree-sha1 = "04bc629fc40d612e1a048c61c3fcbbe1adc3b641"
uuid = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
version = "0.1.2"
version = "0.1.3"

[[Reexport]]
deps = ["Pkg"]
Expand Down Expand Up @@ -790,9 +789,9 @@ uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"

[[SIMDPirates]]
deps = ["VectorizationBase"]
git-tree-sha1 = "74bf6ed250c21651955bdb36b2b12320374c49ae"
git-tree-sha1 = "18dca6ff298fdde2d5d837f8aaba6d54302ebee3"
uuid = "21efa798-c60a-11e8-04d3-e1a92915a26a"
version = "0.8.7"
version = "0.8.10"

[[SLEEFPirates]]
deps = ["Libdl", "SIMDPirates", "VectorizationBase"]
Expand Down Expand Up @@ -846,9 +845,9 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[SparseDiffTools]]
deps = ["Adapt", "ArrayInterface", "Compat", "DataStructures", "FiniteDiff", "ForwardDiff", "LightGraphs", "LinearAlgebra", "Requires", "SparseArrays", "VertexSafeGraphs"]
git-tree-sha1 = "bfe68e0d914952932594b3c838f08463b0841037"
git-tree-sha1 = "567fd5758c8271b81cb6497f1bddf1a2d0dd09af"
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
version = "1.8.0"
version = "1.9.0"

[[SpecialFunctions]]
deps = ["BinDeps", "BinaryProvider", "Libdl"]
Expand Down Expand Up @@ -966,15 +965,15 @@ uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[[Unitful]]
deps = ["ConstructionBase", "LinearAlgebra", "Random"]
git-tree-sha1 = "3714b55de06b11b2aa788b8643d6e91f13648be5"
git-tree-sha1 = "a061dada333813818aa7454f93c63a5cab6ea981"
uuid = "1986cc42-f94f-5a68-af5c-568840ba703d"
version = "1.2.1"
version = "1.3.0"

[[VectorizationBase]]
deps = ["CpuId", "LLVM", "Libdl", "LinearAlgebra"]
git-tree-sha1 = "bcadc352d9c81b0ef9ceebe822d30128b779f56b"
git-tree-sha1 = "ed02d6b61057bb6ddf7e8b1dccfec907cc064b36"
uuid = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
version = "0.12.8"
version = "0.12.13"

[[VersionParsing]]
git-tree-sha1 = "80229be1f670524750d905f8fc8148e5a8c4537f"
Expand All @@ -989,9 +988,9 @@ version = "0.1.2"

[[Zygote]]
deps = ["AbstractFFTs", "ArrayLayouts", "ChainRules", "FillArrays", "ForwardDiff", "Future", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "Random", "Requires", "Statistics", "ZygoteRules"]
git-tree-sha1 = "6d0f78976db6dbea9a36865efe068e6e2a5db6ed"
git-tree-sha1 = "2e2c82549fb0414df10469082fd001e2ede8547c"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.4.21"
version = "0.4.22"

[[ZygoteRules]]
deps = ["MacroTools"]
Expand Down
60 changes: 55 additions & 5 deletions src/EKI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ export EKIObj
export construct_initial_ensemble
export compute_error
export update_ensemble!

export find_eki_step

"""
EKIObj{FT<:AbstractFloat, IT<:Int}
Expand All @@ -36,6 +36,8 @@ struct EKIObj{FT<:AbstractFloat, IT<:Int}
g::Vector{Array{FT, 2}}
"vector of errors"
err::Vector{FT}
"vector of timesteps used in each EKI iteration"
Δt::Vector{FT}
end

# outer constructors
Expand All @@ -53,9 +55,11 @@ function EKIObj(parameters::Array{FT, 2},
# observations
g = Vector{FT}[]
# error store
err = []
err = FT[]
# timestep store
Δt = FT[]

EKIObj{FT,IT}(u, parameter_names, t_mean, t_cov, N_ens, g, err)
EKIObj{FT,IT}(u, parameter_names, t_mean, t_cov, N_ens, g, err, Δt)
end


Expand Down Expand Up @@ -91,9 +95,10 @@ function compute_error(eki)
end


function update_ensemble!(eki::EKIObj{FT}, g) where {FT}
function update_ensemble!(eki::EKIObj{FT}, g; cov_threshold::FT=0.01, Δt_new = nothing) where {FT}
# u: N_ens x N_params
u = eki.u[end]
cov_init = cov(eki.u[end], dims=1)

u_bar = fill(FT(0), size(u)[2])
# g: N_ens x N_data
Expand All @@ -102,6 +107,14 @@ function update_ensemble!(eki::EKIObj{FT}, g) where {FT}
cov_ug = fill(FT(0), size(u)[2], size(g)[2])
cov_gg = fill(FT(0), size(g)[2], size(g)[2])

if !isnothing(Δt_new)
push!(eki.Δt, Δt_new)
elseif isnothing(Δt_new) && isempty(eki.Δt)
push!(eki.Δt, FT(1))
else
push!(eki.Δt, eki.Δt[end])
end

# update means/covs with new param/observation pairs u, g
for j = 1:eki.N_ens

Expand All @@ -123,7 +136,7 @@ function update_ensemble!(eki::EKIObj{FT}, g) where {FT}
cov_gg = cov_gg / eki.N_ens - g_bar * g_bar'

# update the parameters (with additive noise too)
noise = rand(MvNormal(zeros(size(g)[2]), eki.cov), eki.N_ens) # N_data * N_ens
noise = rand(MvNormal(zeros(size(g)[2]), eki.cov/eki.Δt[end]), eki.N_ens) # N_data * N_ens
y = (eki.g_t .+ noise)' # add g_t (N_data) to each column of noise (N_data x N_ens), then transp. into N_ens x N_data
tmp = (cov_gg + eki.cov) \ (y - g)' # N_data x N_data \ [N_ens x N_data - N_ens x N_data]' --> tmp is N_data x N_ens
u += (cov_ug * tmp)' # N_ens x N_params
Expand All @@ -134,6 +147,43 @@ function update_ensemble!(eki::EKIObj{FT}, g) where {FT}

compute_error(eki)

# Check convergence
cov_new = cov(eki.u[end], dims=1)
cov_ratio = det(cov_new)/det(cov_init)
if cov_ratio < cov_threshold
@warn string("New ensemble covariance determinant is less than ",cov_threshold," times its former value.
Consider reducing the EKI time step.")
end
end


"""
find_eki_step(eki::EKIObj{FT}, g::Array{FT, 2}; cov_threshold::FT=0.01) where {FT}
Find largest step for the EKI solver that leads to a reduction of the determinant of the sample
covariance matrix no greater than cov_threshold.
"""
function find_eki_step(eki::EKIObj{FT}, g::Array{FT, 2}; cov_threshold::FT=0.01) where {FT}
accept_step = false
if !isempty(eki.Δt)
Δt = deepcopy(eki.Δt[end])
else
Δt = FT(1)
end
# u: N_ens x N_params
cov_init = cov(eki.u[end], dims=1)
while accept_step == false
eki_copy = deepcopy(eki)
update_ensemble!(eki_copy, g, Δt_new=Δt)
cov_new = cov(eki_copy.u[end], dims=1)
if det(cov_new) > cov_threshold*det(cov_init)
accept_step = true
else
Δt = Δt/2
end
end

return Δt

end

end # module EKI
Loading

0 comments on commit a163957

Please sign in to comment.