add diags and plotting for it in examples, VERY HACKY
updates from hpc

rm prints

repeats for uq_for_edmf

add save jld2 and lines+series depending on repeats

add jld2 and log-scale

rm prints, typos

format
odunbar committed Jul 11, 2024
1 parent 4b6d293 commit 842ba2d
Showing 8 changed files with 223 additions and 62 deletions.
13 changes: 9 additions & 4 deletions examples/EDMF_data/plot_posterior.jl
@@ -13,12 +13,12 @@ using CalibrateEmulateSample.ParameterDistributions
# date = Date(year,month,day)

# 2-parameter calibration exp
exp_name = "ent-det-calibration"
date_of_run = Date(2023, 10, 5)
#exp_name = "ent-det-calibration"
#date_of_run = Date(2023, 10, 17)

# 5-parameter calibration exp
#exp_name = "ent-det-tked-tkee-stab-calibration"
#date_of_run = Date(2023,10,4)
exp_name = "ent-det-tked-tkee-stab-calibration"
date_of_run = Date(2024, 2, 2)

# Output figure read/write directory
figure_save_directory = joinpath(@__DIR__, "output", exp_name, string(date_of_run))
@@ -50,3 +50,8 @@ p = pairplot(data => (PairPlots.Scatter(),))
trans_p = pairplot(transformed_data => (PairPlots.Scatter(),))
save(density_filepath, p)
save(transformed_density_filepath, trans_p)

density_filepath = joinpath(figure_save_directory, "posterior_dist_comp.pdf")
transformed_density_filepath = joinpath(figure_save_directory, "posterior_dist_phys.pdf")
save(density_filepath, p)
save(transformed_density_filepath, trans_p)
148 changes: 106 additions & 42 deletions examples/EDMF_data/uq_for_edmf.jl
@@ -1,11 +1,12 @@
#include(joinpath(@__DIR__, "..", "ci", "linkfig.jl"))
#includef(joinpath(@__DIR__, "..", "ci", "linkfig.jl"))
PLOT_FLAG = false

# Import modules
using Distributions # probability distributions and associated functions
using LinearAlgebra
ENV["GKSwstype"] = "100"
using Plots
using CairoMakie
using Random
using JLD2
using NCDatasets
@@ -28,10 +29,10 @@ Random.seed!(rng_seed)
function main()

# 2-parameter calibration exp
exp_name = "ent-det-calibration"
#exp_name = "ent-det-calibration"

# 5-parameter calibration exp
#exp_name = "ent-det-tked-tkee-stab-calibration"
exp_name = "ent-det-tked-tkee-stab-calibration"


# Output figure save directory
@@ -120,6 +121,7 @@ function main()
for plot_i in 1:size(outputs, 1)
p = scatter(inputs_constrained[1, :], inputs_constrained[2, :], zcolor = outputs[plot_i, :])
savefig(p, joinpath(figure_save_directory, "output_" * string(plot_i) * ".png"))
savefig(p, joinpath(figure_save_directory, "output_" * string(plot_i) * ".pdf"))
end
println("finished plotting ensembles.")
end
@@ -201,52 +203,114 @@ function main()
cases = [
"GP", # diagonalize, train scalar GP, assume diag inputs
"RF-vector-svd-nonsep",
"RF-vector-nosvd-nonsep", # don't perform decorrelation
]
case = cases[2]

overrides = Dict(
"verbose" => true,
"train_fraction" => 0.95,
"scheduler" => DataMisfitController(terminate_at = 100),
"cov_sample_multiplier" => 0.5,
"n_iteration" => 5,
)
nugget = 0.01
rng_seed = 99330
rng = Random.MersenneTwister(rng_seed)
input_dim = size(get_inputs(input_output_pairs), 1)
output_dim = size(get_outputs(input_output_pairs), 1)
if case == "GP"

gppackage = Emulators.SKLJL()
pred_type = Emulators.YType()
mlt = GaussianProcess(
gppackage;
kernel = nothing, # use default squared exponential kernel
prediction_type = pred_type,
noise_learn = false,
case = cases[3]
n_repeats = 2

opt_diagnostics = []
emulators = []
for rep_idx in 1:n_repeats

overrides = Dict(
"verbose" => true,
"train_fraction" => 0.9, #95
"scheduler" => DataMisfitController(terminate_at = 1e5),
"cov_sample_multiplier" => 0.4,
"n_features_opt" => 200,
"n_iteration" => 15,
# "n_ensemble" => 20,
# "localization" => SEC(1.0, 0.01), # localization / sample error correction for small ensembles
)
elseif case ["RF-vector-svd-nonsep"]
kernel_structure = NonseparableKernel(LowRankFactor(3, nugget))
n_features = 500

mlt = VectorRandomFeatureInterface(
n_features,
input_dim,
output_dim,
rng = rng,
kernel_structure = kernel_structure,
optimizer_options = overrides,
nugget = 1e-10#1e-12#0.01
rng_seed = 99330
rng = Random.MersenneTwister(rng_seed)
input_dim = size(get_inputs(input_output_pairs), 1)
output_dim = size(get_outputs(input_output_pairs), 1)
decorrelate = true
if case == "GP"

gppackage = Emulators.SKLJL()
pred_type = Emulators.YType()
mlt = GaussianProcess(
gppackage;
kernel = nothing, # use default squared exponential kernel
prediction_type = pred_type,
noise_learn = false,
)
elseif case ["RF-vector-svd-nonsep"]
kernel_structure = NonseparableKernel(LowRankFactor(3, nugget))
n_features = 500

mlt = VectorRandomFeatureInterface(
n_features,
input_dim,
output_dim,
rng = rng,
kernel_structure = kernel_structure,
optimizer_options = overrides,
)
elseif case ["RF-vector-nosvd-nonsep"]
kernel_structure = NonseparableKernel(LowRankFactor(3, nugget))
n_features = 500

mlt = VectorRandomFeatureInterface(
n_features,
input_dim,
output_dim,
rng = rng,
kernel_structure = kernel_structure,
optimizer_options = overrides,
)
decorrelate = false
end

# Fit an emulator to the data
normalized = true

emulator = Emulator(
mlt,
input_output_pairs;
obs_noise_cov = truth_cov,
normalize_inputs = normalized,
decorrelate = decorrelate,
)

# Optimize the GP hyperparameters for better fit
optimize_hyperparameters!(emulator)
if case ["RF-vector-nosvd-nonsep", "RF-vector-svd-nonsep"]
push!(opt_diagnostics, get_optimizer(mlt)[1]) #length-1 vec of vec -> vec
end

for rep_idx in n_repeats
push!(emulators, emulator)
end
end
emulator = emulators[1]

# Fit an emulator to the data
normalized = true
# plot eki convergence plot
if length(opt_diagnostics) > 0
err_cols = reduce(hcat, opt_diagnostics) #error for each repeat as columns?

emulator = Emulator(mlt, input_output_pairs; obs_noise_cov = truth_cov, normalize_inputs = normalized)
#save data
error_filepath = joinpath(data_save_directory, "eki_conv_error.jld2")
save(error_filepath, "error", err_cols)

# Optimize the GP hyperparameters for better fit
optimize_hyperparameters!(emulator)
# print all repeats
f5 = Figure(resolution = (1.618 * 300, 300), markersize = 4)
ax_conv = Axis(f5[1, 1], xlabel = "Iteration", ylabel = "max-normalized error")
if n_repeats == 1
lines!(ax_conv, collect(1:size(err_cols, 1))[:], err_cols[:], solid_color = :blue) # If just one repeat
else
for idx in 1:size(err_cols, 1)
err_normalized = (err_cols' ./ err_cols[1, :])' # divide each series by the max, so all errors start at 1
series!(ax_conv, err_normalized', solid_color = :blue)
end
end
save(joinpath(figure_save_directory, "eki-conv_$(case).png"), f5, px_per_unit = 3)
save(joinpath(figure_save_directory, "eki-conv_$(case).pdf"), f5, px_per_unit = 3)

end

emulator_filepath = joinpath(data_save_directory, "emulator.jld2")
save(emulator_filepath, "emulator", emulator)
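A note on the normalization idiom repeated in the new convergence-plotting blocks: the in-code comment says each series is divided by the max, but the broadcast actually divides each repeat's error trajectory by its first iterate, so every curve starts at 1. A minimal standalone sketch with a fabricated 3-iteration, 2-repeat matrix (assumed layout: err_cols is n_iterations × n_repeats, as produced by reduce(hcat, opt_diagnostics)):

# stand-in for the stacked per-repeat EKI error diagnostics
err_cols = [2.0 4.0; 1.0 2.0; 0.5 1.0]
# the committed idiom: divide each column (repeat) by its first entry
err_normalized = (err_cols' ./ err_cols[1, :])'
# the same operation without the transposes
@assert err_normalized ≈ err_cols ./ err_cols[1:1, :]
@assert all(err_normalized[1, :] .≈ 1.0)   # every repeat now starts at 1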
1 change: 1 addition & 0 deletions examples/Emulator/Ishigami/Project.toml
@@ -5,5 +5,6 @@ ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
GlobalSensitivityAnalysis = "1b10255b-6da3-57ce-9089-d24e8517b87e"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
42 changes: 38 additions & 4 deletions examples/Emulator/Ishigami/emulate.jl
Expand Up @@ -9,6 +9,7 @@ using LinearAlgebra
using CalibrateEmulateSample.EnsembleKalmanProcesses
using CalibrateEmulateSample.Emulators
using CalibrateEmulateSample.DataContainers
using CalibrateEmulateSample.EnsembleKalmanProcesses.Localizers

using CairoMakie, ColorSchemes #for plots
seed = 2589456
@@ -81,9 +82,13 @@ function main()
case = cases[3]
decorrelate = true
nugget = Float64(1e-12)

overrides =
Dict("verbose" => true, "scheduler" => DataMisfitController(terminate_at = 1e4), "n_features_opt" => 200)
overrides = Dict(
"scheduler" => DataMisfitController(terminate_at = 1e4),
"n_features_opt" => 150,
"n_ensemble" => 30,
"n_iteration" => 20,
"accelerator" => NesterovAccelerator(),
)
if case == "Prior"
# don't do anything
overrides["n_iteration"] = 0
@@ -92,7 +97,7 @@

y_preds = []
result_preds = []

opt_diagnostics = []
for rep_idx in 1:n_repeats

# Build ML tools
@@ -118,6 +123,11 @@
emulator = Emulator(mlt, iopairs; obs_noise_cov = Γ * I, decorrelate = decorrelate)
optimize_hyperparameters!(emulator)

# get EKP errors - just stored in "optimizer" box for now
if case == "RF-scalar"
diag_tmp = reduce(hcat, get_optimizer(mlt)) # (n_iteration, dim_output=1) convergence for each scalar mode as cols
push!(opt_diagnostics, diag_tmp)
end
# predict on all Sobol points with emulator (example)
y_pred, y_var = predict(emulator, samples', transform_to_real = true)

@@ -186,6 +196,30 @@ save(joinpath(output_directory, "ishigami_slices_$(case).pdf"), f2, px_per_unit = 3)
save(joinpath(output_directory, "ishigami_slices_$(case).pdf"), f2, px_per_unit = 3)


if length(opt_diagnostics) > 0
err_cols = reduce(hcat, opt_diagnostics) #error for each repeat as columns?

#save
error_filepath = joinpath(output_directory, "eki_conv_error.jld2")
save(error_filepath, "error", err_cols)

# print all repeats
f3 = Figure(resolution = (1.618 * 300, 300), markersize = 4)
ax_conv = Axis(f3[1, 1], xlabel = "Iteration", ylabel = "Error")

if n_repeats == 1
lines!(ax_conv, collect(1:size(err_cols, 1))[:], err_cols[:], solid_color = :blue) # If just one repeat
else
for idx in 1:size(err_cols, 1)
err_normalized = (err_cols' ./ err_cols[1, :])' # divide each series by the max, so all errors start at 1
series!(ax_conv, err_normalized', solid_color = :blue)
end
end

save(joinpath(output_directory, "ishigami_eki-conv_$(case).png"), f3, px_per_unit = 3)
save(joinpath(output_directory, "ishigami_eki-conv_$(case).pdf"), f3, px_per_unit = 3)

end
end


45 changes: 37 additions & 8 deletions examples/Emulator/L63/emulate.jl
@@ -23,7 +23,7 @@ function main()
end

# rng
rng = MersenneTwister(1232434)
rng = MersenneTwister(1232435)

n_repeats = 20 # repeat exp with same data.
println("run experiment $n_repeats times")
@@ -92,20 +92,22 @@
# Emulate
cases = ["GP", "RF-scalar", "RF-scalar-diagin", "RF-svd-nonsep", "RF-nosvd-nonsep", "RF-nosvd-sep"]

case = cases[1]
case = cases[5]

nugget = Float64(1e-12)
u_test = []
u_hist = []
train_err = []
opt_diagnostics = []

for rep_idx in 1:n_repeats

rf_optimizer_overrides = Dict(
"scheduler" => DataMisfitController(terminate_at = 1e4),
"cov_sample_multiplier" => 0.5,
"n_features_opt" => 400,
"n_iteration" => 30,
"accelerator" => ConstantStepNesterovAccelerator(),
"cov_sample_multiplier" => 1.0,
"n_features_opt" => 200,
"n_iteration" => 10, #30
"accelerator" => NesterovAccelerator(),
)

# Build ML tools
@@ -170,6 +172,11 @@
emulator = Emulator(mlt, iopairs; obs_noise_cov = Γy, decorrelate = decorrelate)
optimize_hyperparameters!(emulator)

# diagnostics
if case == "RF-nosvd-nonsep"
push!(opt_diagnostics, get_optimizer(mlt)[1]) #length-1 vec of vec -> vec
end


# Predict with emulator
u_test_tmp = zeros(3, length(xspan_test))
@@ -252,6 +259,30 @@
JLD2.save(joinpath(output_directory, case * "_l63_histdata.jld2"), "solhist", solhist, "uhist", u_hist)
JLD2.save(joinpath(output_directory, case * "_l63_testdata.jld2"), "solplot", solplot, "uplot", u_test)

# plot eki convergence plot
if length(opt_diagnostics) > 0
err_cols = reduce(hcat, opt_diagnostics) #error for each repeat as columns?

#save
error_filepath = joinpath(output_directory, "eki_conv_error.jld2")
save(error_filepath, "error", err_cols)

# print all repeats
f5 = Figure(resolution = (1.618 * 300, 300), markersize = 4)
ax_conv = Axis(f5[1, 1], xlabel = "Iteration", ylabel = "max-normalized error", yscale = log10)
if n_repeats == 1
lines!(ax_conv, collect(1:size(err_cols, 1))[:], err_cols[:], solid_color = :blue) # If just one repeat
else
for idx in 1:size(err_cols, 1)
err_normalized = (err_cols' ./ err_cols[1, :])' # divide each series by the max, so all errors start at 1
series!(ax_conv, err_normalized', solid_color = :blue)
end
end
save(joinpath(output_directory, "l63_eki-conv_$(case).png"), f5, px_per_unit = 3)
save(joinpath(output_directory, "l63_eki-conv_$(case).pdf"), f5, px_per_unit = 3)

end

# compare marginal histograms to truth - rough measure of fit
sol_cdf = sort(solhist, dims = 2)

@@ -278,8 +309,6 @@
lines!(axy, sol_cdf[2, :], unif_samples, color = (:orange, 1.0), linewidth = 4)
lines!(axz, sol_cdf[3, :], unif_samples, color = (:orange, 1.0), linewidth = 4)



# save
save(joinpath(output_directory, case * "_l63_cdfs.png"), f4, px_per_unit = 3)
save(joinpath(output_directory, case * "_l63_cdfs.pdf"), f4, pt_per_unit = 3)
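For reference, a minimal standalone CairoMakie sketch of the convergence figure these examples now build (fabricated decaying errors, hypothetical output path; the scripts plot the normalized opt_diagnostics gathered during training, and the L63 version additionally uses the log-scaled y axis shown above):

using CairoMakie

# fabricated stand-in: 10 iterations × 3 repeats of a decaying EKI error
err_cols = [exp(-0.4 * i) * (1.0 + 0.1 * j) for i in 1:10, j in 1:3]
err_normalized = err_cols ./ err_cols[1:1, :]   # rescale so each repeat starts at 1

f = Figure(resolution = (1.618 * 300, 300))
ax = Axis(f[1, 1], xlabel = "Iteration", ylabel = "normalized error (rel. to iteration 1)", yscale = log10)
if size(err_normalized, 2) == 1
    lines!(ax, 1:size(err_normalized, 1), vec(err_normalized))   # single repeat
else
    series!(ax, err_normalized', solid_color = :blue)            # one row per repeat
end
save(joinpath(mktempdir(), "eki-conv_sketch.png"), f, px_per_unit = 3)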
1 change: 0 additions & 1 deletion src/MarkovChainMonteCarlo.jl
@@ -298,7 +298,6 @@ function AbstractMCMC.bundle_samples(
)
# Turn all the transitions into a vector-of-vectors.
vals = [vcat(t.params, t.log_density, t.accepted) for t in ts]

# Check if we received any parameter names.
if ismissing(param_names)
param_names = [Symbol(:param_, i) for i in 1:length(keys(ts[1].params))]