Commit
194: [WIP] Random Feature-based CES r=odunbar a=odunbar

## Purpose

Adds the ability to use (scalar and vector-valued) RF with uncertainty in place of GP within CES, using `RandomFeatures.jl`.

Closes #164

## Content

- Interfaces with the currently registered RandomFeatures.jl
- adds `ScalarRandomFeatureInterface` as a `MachineLearningTool`
- adds `VectorRandomFeatureInterface` as a `MachineLearningTool`
- new example `examples/Emulator/RandomFeature/optimize_and_plot_RF.jl`
- new example `examples/Emulator/RandomFeature/vector_optimize_and_plot_RF.jl`
- new example `examples/Lorenz/calibrate.jl`
- new example `examples/Lorenz/emulate_sample.jl`

The current implementation has Scalar RF replacing (exactly) the GP, whereas Vector RF does no SVD, and therefore learns the output space correlations. The hyperparameter learning is more involved, so to reduce some cost I learn the Cholesky factors of an input and output covariance of the feature distribution, currently described by a MatrixVariate Normal distribution.

- new example `examples/GCM/emulate_sample_script.jl`, though *currently just the emulation!*

In this example we have 5 options (note that in all cases we train on Cholesky factors for the input variables):

1. `GPR` trains an `output_dim`-length vector of scalar GPRs.
2. `Scalar RFR SVD` replaces the vector of scalar GPRs with a vector of scalar RFRs.
3. `Vector RFR SVD Diagonal` assumes a diagonalized output in the vector problem (i.e. still in the setting of a system of Scalar RFs & GPs, but only trains one object).
4. `Vector RFR SVD nondiagonal` still applies the SVD, but does not assume that the resulting output must be diagonal.
It therefore learns Cholesky factors of the output.
5. `Vector RFR nondiagonal` does not apply SVD, nor does it assume the output is diagonal. It learns the Cholesky factors of the direct output.

### Emulating an R^2 to R^2 function (150 data points)

1) SVD + Scalar GP (diag in) results

<img src="https://user-images.githubusercontent.com/47412152/192404099-be8d1241-2dd4-4263-ba2a-31de94763abb.png" width="300"> <img src="https://user-images.githubusercontent.com/47412152/192404100-62d72ccc-2b36-4ba9-ad36-38bcdf4b9f0f.png" width="300">

2) SVD + Scalar RF (nondiag in) results

<img src="https://user-images.githubusercontent.com/47412152/230235711-6bb0557e-8914-4a43-8f91-f5a144659edc.png" width="300"> <img src="https://user-images.githubusercontent.com/47412152/230235715-54fb7d5e-fa24-4528-a3fc-27ff7e9aceb8.png" width="300">

3) SVD + vector RF (diag out) results

<img src="https://user-images.githubusercontent.com/47412152/230229962-c7eefa25-3a57-467c-8ca1-c9ef7b3dbb3e.png" width="300"> <img src="https://user-images.githubusercontent.com/47412152/230229964-fcfcbe7c-73cd-4837-8db7-5eaa339eaec6.png" width="300">

4) SVD + vector RF (nondiag out) results

<img src="https://user-images.githubusercontent.com/47412152/230230124-bb50e4db-8ba7-4570-930e-b6504936a1b5.png" width="300"> <img src="https://user-images.githubusercontent.com/47412152/230230128-2422e85a-1100-4422-a52c-8fd32183b7f0.png" width="300">

5) vector RF (diag out) results

<img src="https://user-images.githubusercontent.com/47412152/230230033-618dcfa8-99a7-4462-b31f-e9adf302dc14.png" width="300"> <img src="https://user-images.githubusercontent.com/47412152/230230040-7dad6ce8-b5dc-4057-9693-a71e0a72379c.png" width="300">

6) vector RF (nondiag out) results

<img src="https://user-images.githubusercontent.com/47412152/230230167-25377c56-c622-493d-bd0e-b98fc672189a.png" width="300"> <img src="https://user-images.githubusercontent.com/47412152/230230169-23c659eb-4218-4b27-9625-94a7fd44a941.png" width="300">

### Emulating GCM data R^2 -> R^96, evaluated at a test point

#### SVD + Scalar RF results

<img src="https://user-images.githubusercontent.com/47412152/219200986-9a5f74e4-5e2a-48cf-8e26-5d66de2e751c.png" width="150"> <img src="https://user-images.githubusercontent.com/47412152/219200996-f5b88c6e-8b51-4df0-acac-187a6a786a78.png" width="150"> <img src="https://user-images.githubusercontent.com/47412152/219201000-afac8c7a-aa9a-4400-93e8-75f560954bba.png" width="150">

#### SVD + Vector RF results (restricted to diagonal) [hparam learnt with 202 features]

<img src="https://user-images.githubusercontent.com/47412152/219200373-7b0e2713-c3db-4891-9012-6852381266b5.png" width="150"> <img src="https://user-images.githubusercontent.com/47412152/219252222-f75ba2c8-b11a-42eb-aa1c-f2308a775041.png" width="150"> <img src="https://user-images.githubusercontent.com/47412152/219200387-d328a6b6-5f64-4eda-bc7e-3a122054c2a5.png" width="150">

#### SVD + vector RF results (full non-diagonal) [hparam learnt with 608 features]

<img src="https://user-images.githubusercontent.com/47412152/220444681-20b3ef41-5347-4406-afd6-ffa8cfc1e1b8.png" width="150"> <img src="https://user-images.githubusercontent.com/47412152/220444689-92f24b34-6937-4e19-b733-2d1b263ca9f7.png" width="150"> <img src="https://user-images.githubusercontent.com/47412152/220444685-70f88334-3c32-438e-8534-e4ca0a1c24d2.png" width="150">

#### No-SVD, with vector RF results (full non-diagonal) + standardize each data-type by median

<img src="https://user-images.githubusercontent.com/47412152/235567696-4c5665b1-33db-4c83-a554-f003e5e015b6.png" width="150"> <img src="https://user-images.githubusercontent.com/47412152/235567703-0fae36f9-e49b-4cfc-b1b4-16a313af51b8.png" width="150"> <img src="https://user-images.githubusercontent.com/47412152/235567706-1d9513f7-ec37-4fad-b824-a3c62b7759d0.png" width="150">

#### SVD + GP results

<img src="https://user-images.githubusercontent.com/47412152/219201341-13acb758-a444-4e05-98c9-ed6975dbd094.png" width="150">
<img src="https://user-images.githubusercontent.com/47412152/219201347-d5c13d9f-3d63-456b-8059-fea0f31346a0.png" width="150"> <img src="https://user-images.githubusercontent.com/47412152/219201351-fa2a4a61-bf75-4292-8549-8f304f05cebe.png" width="150">

### Full CES test (with "E" emulating an R^2 -> R^12 forward map) 250 data points

1) SVD + Scalar GP (diag in) results
2) SVD + Scalar RF (diag in) results
3) SVD + Scalar RF (nondiag in) results
4) SVD + vector RF (nondiag in, diag out) results
5) SVD + vector RF (nondiag in, nondiag out) results
6) vector RF (nondiag in, diag out) results
7) vector RF (nondiag in, nondiag out) results

<img src="https://user-images.githubusercontent.com/47412152/236000320-bbf88ee3-6de7-4e8e-8797-13f48696337a.png" width="175"> <img src="https://user-images.githubusercontent.com/47412152/236364924-c41dd024-e56e-4506-ad19-2fc324d0db61.png" width="175"> <img src="https://user-images.githubusercontent.com/47412152/236364926-f74d9467-7083-4fb9-8c89-b6ea4ab8aad3.png" width="175"> <img src="https://user-images.githubusercontent.com/47412152/236364930-a0424381-ccc5-472c-b85f-e24aec7319f0.png" width="175"> <img src="https://user-images.githubusercontent.com/47412152/236364932-3cb2e6c6-0f5b-4544-9916-cddfcbd3c882.png" width="175"> <img src="https://user-images.githubusercontent.com/47412152/236364927-b05e1bc8-0430-4db8-97dc-29b0fce0f305.png" width="175"> <img src="https://user-images.githubusercontent.com/47412152/236364929-5e78ceb3-07e9-4230-95b1-2e638dab1ce3.png" width="175">

Co-authored-by: odunbar <[email protected]>
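
### Note: why Cholesky factors of an input and output covariance

A short worked sketch of the parameterization mentioned under Content. This is an illustration under stated assumptions, not a description of what `RandomFeatures.jl` does internally. Assume a zero-mean matrix-variate normal over a $d \times p$ feature-parameter matrix $\Xi$, with $d$ the input dimension and $p$ the output dimension. The symbols $\Xi$, $U$, $V$, $L_U$, $L_V$ and $Z$ are notation introduced here, and which of $U$, $V$ the code calls "input" or "output" is an assumption.

$$
\begin{aligned}
U &= L_U L_U^\top \in \mathbb{R}^{d \times d}, \qquad V = L_V L_V^\top \in \mathbb{R}^{p \times p}, \\
\Xi &= L_U \, Z \, L_V^\top, \qquad Z_{ij} \overset{\text{iid}}{\sim} \mathcal{N}(0, 1), \\
\Rightarrow \quad \Xi &\sim \mathcal{MN}_{d \times p}(0,\, U,\, V), \qquad \operatorname{vec}(\Xi) \sim \mathcal{N}\!\left(0,\; V \otimes U\right).
\end{aligned}
$$

Learning the lower-triangular factors $L_U$ and $L_V$ keeps $U$ and $V$ positive semi-definite by construction, and the number of free hyperparameters is

$$
\frac{d(d+1)}{2} + \frac{p(p+1)}{2}
\quad \text{instead of} \quad
\frac{dp\,(dp+1)}{2}
$$

for an unstructured covariance of $\operatorname{vec}(\Xi)$. As an example, taking $d = 2$ and $p = 12$ (the dimensions of the R^2 -> R^12 map, ignoring any SVD truncation) gives $3 + 78 = 81$ parameters instead of $300$.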