Skip to content

Commit

Permalink
added RF emulator
Browse files Browse the repository at this point in the history
update project toml

examples for RF

shuffle data

update example

example to produce comparable figs

updates for compatability with CES 0.2.0 and RF 0.1.0

vector random feature support

added fixes to ensure CES pipeline runs

regularization and lorenz example

format

allows training with fewer features than data

VRFI with SVD, and cholesky options

feature num dep on n

GCM example

replace data

multithreading and rng

add ProgressBars

remove high-level threading for now (takes place within LinAlg solvers)

sbatch script

truth at some points

increased number optimization features default

bugfix reg matrix argument

initial tik-reg for EKI

working TEKI

0 default eki, small tweaks

add logdet complexity

more consistent adding of definiteness

chol/svd

add logdet to scalar learning

shape bug

logdetI

unite common functions in Random Feature, expand Scalar feature learning

extend reverse svd for covs

add diag terms to MatrixNormal description, default to diagonal regularizations rather than pos-def

add diagonal option

trimmed, and added const hp for diag cov

compat with svd truncation, and more standard posdef corrections

added scaling to complexity data

change scalar interface

lorenz 2d statsplot

combine all MLT examples into this

improved interfacing, unification and initial unit testing

condensed into emulate_sample

simplify scalar interface

bug

improved vector interface

reg should be multiplicative! fixed

small edits

update ess.jl

MSE on next ensemble, add input-diag case

inflation

optimizer defaults and cov representation

inflation vec

inflation

utility for ensembles

test pass with new defaults and cov structure

format

format

add RF tests

GP test fails resolved

scalar_optimize_and_plot_RF.jl

with new RF accel

removed some abstract types, compatible with RandomFeatures 0.2.5

format

typo

another typo

compatible with v0.3 RandomFeatures

dispatching over RandomFeatures v0.3.1 multthread options

typo

more flexible priors

compataility for SRF and RF v0.3.1

updates to GCM example scripts

docstrings

docstring API and format

rm duplicate API docs

format

API docs work locally

rename

test pass

format

better messages, bugfix priors for diagonalized options

emulation test scenarios for scalar and vector RF

multithread supp for lorenz example

added cov samples user option, added opt-option for threading in prediction

test

format

Lorenz example config

verbose flag

tests pass

format

test for tullio threading

format
  • Loading branch information
odunbar committed May 5, 2023
1 parent 045ee4a commit defb3ee
Show file tree
Hide file tree
Showing 31 changed files with 3,282 additions and 246 deletions.
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RandomFeatures = "36c3bae2-c0c3-419d-b3b4-eebadd35c5e5"
ScikitLearn = "3646fa90-6ef7-5e7e-9f22-8aca16db6324"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

Expand All @@ -32,8 +35,9 @@ GaussianProcesses = "0.12"
MCMCChains = "4.14, 5, 6"
PyCall = "1.93"
ScikitLearn = "0.6, 0.7"
RandomFeatures = "0.3"
StatsBase = "0.33"
julia = "1.6"
julia = "1.6, 1.7, 1.8"

[extras]
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Expand Down
1 change: 0 additions & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
[deps]
CalibrateEmulateSample = "95e48a1f-0bec-4818-9538-3db4340308e3"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"

6 changes: 5 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@ design = ["AbstractMCMC sampling API" => "API/AbstractMCMC.md"]

api = [
"CalibrateEmulateSample" => [
"Emulators" => ["General Emulator" => "API/Emulators.md", "Gaussian Process" => "API/GaussianProcess.md"],
"Emulators" => [
"General Interface" => "API/Emulators.md",
"Gaussian Process" => "API/GaussianProcess.md",
"Random Features" => "API/RandomFeatures.md",
],
"MarkovChainMonteCarlo" => "API/MarkovChainMonteCarlo.md",
"Utilities" => "API/Utilities.md",
],
Expand Down
3 changes: 2 additions & 1 deletion docs/src/API/GaussianProcess.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ CurrentModule = CalibrateEmulateSample.Emulators
GaussianProcessesPackage
PredictionType
GaussianProcess
build_models!
build_models!(::GaussianProcess{GPJL}, ::PairedDataContainer{FT}) where {FT <: AbstractFloat}
optimize_hyperparameters!(::GaussianProcess{GPJL})
predict(::GaussianProcess{GPJL}, ::AbstractMatrix{FT}) where {FT <: AbstractFloat}
```
39 changes: 39 additions & 0 deletions docs/src/API/RandomFeatures.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# RandomFeatures

```@meta
CurrentModule = CalibrateEmulateSample.Emulators
```

## Scalar interface

```@docs
ScalarRandomFeatureInterface
ScalarRandomFeatureInterface(::Int,::Int)
build_models!(::ScalarRandomFeatureInterface, ::PairedDataContainer{FT}) where {FT <: AbstractFloat}
predict(::ScalarRandomFeatureInterface, ::M) where {M <: AbstractMatrix}
```

## Vector Interface

```@docs
VectorRandomFeatureInterface
VectorRandomFeatureInterface(::Int, ::Int, ::Int)
build_models!(::VectorRandomFeatureInterface, ::PairedDataContainer{FT}) where {FT <: AbstractFloat}
predict(::VectorRandomFeatureInterface, ::M) where {M <: AbstractMatrix}
```

## Other utilities
```@docs
get_rfms
get_fitted_features
get_batch_sizes
get_n_features
get_input_dim
get_output_dim
get_rng
get_diagonalize_input
get_feature_decomposition
get_optimizer_options
optimize_hyperparameters!(::ScalarRandomFeatureInterface)
optimize_hyperparameters!(::VectorRandomFeatureInterface)
```
12 changes: 6 additions & 6 deletions examples/Emulator/GaussianProcess/plot_GP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ if !isdir(output_directory)
end

#create the machine learning tools: Gaussian Process
gppackage = GPJL()
gppackage = SKLJL()
pred_type = YType()
gaussian_process = GaussianProcess(gppackage, noise_learn = true)
gaussian_process = GaussianProcess(gppackage, noise_learn = false)

# Generate training data (x-y pairs, where x ∈ ℝ ᵖ, y ∈ ℝ ᵈ)
# x = [x1, x2]: inputs/predictors/features/parameters
Expand All @@ -92,7 +92,7 @@ gx[2, :] = g2x

# Add noise η
μ = zeros(d)
Σ = 0.1 * [[0.8, 0.0] [0.0, 0.5]] # d x d
Σ = 0.1 * [[0.8, 0.1] [0.1, 0.5]] # d x d
noise_samples = rand(MvNormal(μ, Σ), n)
# y = G(x) + η
Y = gx .+ noise_samples
Expand Down Expand Up @@ -182,9 +182,9 @@ println("GP trained")

# Plot mean and variance of the predicted observables y1 and y2
# For this, we generate test points on a x1-x2 grid.
n_pts = 50
x1 = range(0.0, stop = 2 * π, length = n_pts)
x2 = range(0.0, stop = 2 * π, length = n_pts)
n_pts = 200
x1 = range(0.0, stop = (4.0 / 5.0) * 2 * π, length = n_pts)
x2 = range(0.0, stop = (4.0 / 5.0) * 2 * π, length = n_pts)
X1, X2 = meshgrid(x1, x2)
# Input for predict has to be of size N_samples x input_dim
inputs = permutedims(hcat(X1[:], X2[:]), (2, 1))
Expand Down
15 changes: 15 additions & 0 deletions examples/Emulator/RandomFeature/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
[deps]
CalibrateEmulateSample = "95e48a1f-0bec-4818-9538-3db4340308e3"
Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
FiniteDiff = "~2.10"
julia = "~1.6"
Loading

0 comments on commit defb3ee

Please sign in to comment.