diff --git a/examples/Opt-ACE-aHfO2/fit-opt-ace-ahfo2.jl b/examples/Opt-ACE-aHfO2/fit-opt-ace-ahfo2.jl index c759aae2..88da3b5b 100644 --- a/examples/Opt-ACE-aHfO2/fit-opt-ace-ahfo2.jl +++ b/examples/Opt-ACE-aHfO2/fit-opt-ace-ahfo2.jl @@ -30,34 +30,47 @@ n_train, n_test = 50, 50 # Only 50 samples per dataset are used in this example. conf_train, conf_test = split(ds, n_train, n_test) -# ## Hyper-parameter optimization -n_samples = 10 -model1 = ACE -pars = OrderedDict( :species => [[:Hf, :O]], - :body_order => [2, 3, 4], +# ## c. Hyper-parameter optimization + +# Define the model and hyper-parameter value ranges to be optimized. +model = ACE +pars = OrderedDict( :body_order => [2, 3, 4], :polynomial_degree => [3, 4, 5], :rcutoff => [4.5, 5.0, 5.5], :wL => [0.5, 1.0, 1.5], :csp => [0.5, 1.0, 1.5], :r0 => [0.5, 1.0, 1.5]) -ws, int = [1.0, 1.0], true -iap, res = hyperlearn!(n_samples, model1, pars, conf_train; ws = ws, int = int) +# Use random sampling to find the optimal hyper-parameters. +iap, res = hyperlearn!(model, pars, conf_train; + n_samples = 10, sampler = RandomSampler(), + ws = [1.0, 1.0], int = true) + +@save_var res_path iap.β +@save_var res_path iap.β0 +@save_var res_path iap.basis +@save_dataframe res_path res +res +# Plot error vs time +err_time = plot_err_time(res) +@save_fig res_path err_time +DisplayAs.PNG(err_time) -# Post-process output: calculate metrics, create plots, and save results ####### -# Prnt and save optimization results -#results = get_results(ho) -@save_dataframe path res -res +# Alternatively, use latin hypercube sampling to find the optimal hyper-parameters. +iap, res = hyperlearn!(model, pars, conf_train; + n_samples = 3, sampler = LHSampler(), + ws = [1.0, 1.0], int = true) -# Optimal IAP @save_var res_path iap.β @save_var res_path iap.β0 @save_var res_path iap.basis +@save_dataframe res_path res +res # Plot error vs time -err_time = plot_err_time(ho) +err_time = plot_err_time(res) @save_fig res_path err_time DisplayAs.PNG(err_time) + diff --git a/examples/Opt-ACE-aHfO2/utils.jl b/examples/Opt-ACE-aHfO2/utils.jl index 27177d51..6012d56b 100644 --- a/examples/Opt-ACE-aHfO2/utils.jl +++ b/examples/Opt-ACE-aHfO2/utils.jl @@ -27,9 +27,9 @@ function get_results(ho) end # Plot fitting error vs force time (Pareto front) -function plot_err_time(ho) - error = [r[2][:error] for r in ho.results] - times = [r[2][:time_us] for r in ho.results] +function plot_err_time(res) + error = res[!, :error] + times = res[!, :time_us] scatter(times, error, label = "", @@ -37,7 +37,8 @@ function plot_err_time(ho) yaxis = "we MSE(E, E') + wf MSE(F, F')") end -function hyper_loss(p) + +function loss(p) err, e_mae, f_mae, time_us = p[1], p[2], p[3], p[4] e_mae_max, f_mae_max = 0.05, 0.05 if e_mae < e_mae_max && f_mae < f_mae_max @@ -48,14 +49,27 @@ function hyper_loss(p) return loss end +function get_species(confs) + return unique(vcat(unique.(atomic_symbol.(get_system.(confs)))...)) +end + +create_ho(x) = Hyperoptimizer(1) + # hyperlearn! -function hyperlearn!(n_samples, model, pars, conf_train; - ws = [1.0, 1.0], int = true, loss = hyper_loss) +function hyperlearn!(model, pars, conf_train; + n_samples = 5, sampler = RandomSampler(), loss = loss, + ws = [1.0, 1.0], int = true) - s = "ho = Hyperoptimizer($n_samples," * join("$k = $v, " for (k, v) in pars) * ")" + s = "create_ho(sampler) = Hyperoptimizer($n_samples, sampler, " * + join("$k = $v, " for (k, v) in pars) * ")" eval(Meta.parse(s)) + ho = Base.invokelatest(create_ho, sampler) + if (ho.sampler isa LHSampler) || (ho.sampler isa CLHSampler) + Hyperopt.init!(ho.sampler, ho) + end + species = get_species(conf_train) for (i, state...) in ho - basis = model(; state...) + basis = model(; species = species, state...) iap = LBasisPotential(basis) ## Compute energy and force descriptors e_descr_new = compute_local_descriptors(conf_train, iap.basis, pbar = false) @@ -71,7 +85,8 @@ function hyperlearn!(n_samples, model, pars, conf_train; f_mae, f_rmse, f_rsq = calc_metrics(f_pred, f) time_us = estimate_time(conf_train, iap) * 10^6 err = ws[1] * e_rmse^2 + ws[2] * f_rmse^2 - metrics = OrderedDict( :e_mae => e_mae, + metrics = OrderedDict( :error => err, + :e_mae => e_mae, :e_rmse => e_rmse, :e_rsq => e_rsq, :f_mae => f_mae,