From 4c222954d8fbf4aba9d72908557acfae85c3a546 Mon Sep 17 00:00:00 2001 From: Emmanuel Lujan Date: Tue, 9 Jul 2024 12:50:06 -0400 Subject: [PATCH] Reduce optimization example by removing random sampler case. --- examples/Opt-ACE-aHfO2/fit-opt-ace-ahfo2.jl | 32 ++++++--------------- 1 file changed, 9 insertions(+), 23 deletions(-) diff --git a/examples/Opt-ACE-aHfO2/fit-opt-ace-ahfo2.jl b/examples/Opt-ACE-aHfO2/fit-opt-ace-ahfo2.jl index 3f0a186..7a4f2cb 100644 --- a/examples/Opt-ACE-aHfO2/fit-opt-ace-ahfo2.jl +++ b/examples/Opt-ACE-aHfO2/fit-opt-ace-ahfo2.jl @@ -52,14 +52,16 @@ end; model = ACE pars = OrderedDict( :body_order => [2, 3, 4], :polynomial_degree => [3, 4, 5], - :rcutoff => LinRange(3.5, 6.5, 10), - :wL => LinRange(0.3, 1.8, 10), - :csp => LinRange(0.3, 1.8, 10), - :r0 => LinRange(0.3, 1.8, 10)); + :rcutoff => LinRange(4, 6, 10), + :wL => LinRange(0.5, 1.5, 10), + :csp => LinRange(0.5, 1.5, 10), + :r0 => LinRange(0.5, 1.5, 10)); -# Use **random sampling** to find the optimal hyper-parameters. +# Use **latin hypercube sampling** to find the optimal hyper-parameters. +sampler = CLHSampler(dims=[Categorical(3), Categorical(3), Continuous(), + Continuous(), Continuous(), Continuous()]) iap, res = hyperlearn!(model, pars, conf_train; - n_samples = 10, sampler = RandomSampler(), + n_samples = 10, sampler = sampler, loss = custom_loss, ws = [1.0, 1.0], int = true); # Save and show results. @@ -74,21 +76,5 @@ err_time = plot_err_time(res) @save_fig res_path err_time DisplayAs.PNG(err_time) -# Alternatively, use **latin hypercube sampling** to find the optimal hyper-parameters. -sampler = CLHSampler(dims=[Categorical(3), Categorical(3), Continuous(), - Continuous(), Continuous(), Continuous()]) -iap2, res2 = hyperlearn!(model, pars, conf_train; - n_samples = 10, sampler = sampler, - loss = custom_loss, ws = [1.0, 1.0], int = true); +# Alternatively, use **random sampling** using "sampler = RandomSampler()". -# Save and show results. -@save_var res_path iap2.β -@save_var res_path iap2.β0 -@save_var res_path iap2.basis -@save_dataframe res_path res2 -res2 - -# Plot error vs time. -err_time2 = plot_err_time(res2) -@save_fig res_path err_time2 -DisplayAs.PNG(err_time2)