Skip to content

Commit

Permalink
Improvements in hyperparameter optimization.
Browse files Browse the repository at this point in the history
  • Loading branch information
emmanuellujan committed Jun 27, 2024
1 parent 6e375be commit b4420fc
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 23 deletions.
41 changes: 27 additions & 14 deletions examples/Opt-ACE-aHfO2/fit-opt-ace-ahfo2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,34 +30,47 @@ n_train, n_test = 50, 50 # Only 50 samples per dataset are used in this example.
conf_train, conf_test = split(ds, n_train, n_test)


# ## Hyper-parameter optimization
n_samples = 10
model1 = ACE
pars = OrderedDict( :species => [[:Hf, :O]],
:body_order => [2, 3, 4],
# ## c. Hyper-parameter optimization

# Define the model and hyper-parameter value ranges to be optimized.
model = ACE
pars = OrderedDict( :body_order => [2, 3, 4],
:polynomial_degree => [3, 4, 5],
:rcutoff => [4.5, 5.0, 5.5],
:wL => [0.5, 1.0, 1.5],
:csp => [0.5, 1.0, 1.5],
:r0 => [0.5, 1.0, 1.5])
ws, int = [1.0, 1.0], true
iap, res = hyperlearn!(n_samples, model1, pars, conf_train; ws = ws, int = int)
# Use random sampling to find the optimal hyper-parameters.
iap, res = hyperlearn!(model, pars, conf_train;
n_samples = 10, sampler = RandomSampler(),
ws = [1.0, 1.0], int = true)

@save_var res_path iap.β
@save_var res_path iap.β0
@save_var res_path iap.basis
@save_dataframe res_path res
res

# Plot error vs time
err_time = plot_err_time(res)
@save_fig res_path err_time
DisplayAs.PNG(err_time)

# Post-process output: calculate metrics, create plots, and save results #######

# Print and save optimization results
#results = get_results(ho)
@save_dataframe path res
res
# Alternatively, use latin hypercube sampling to find the optimal hyper-parameters.
iap, res = hyperlearn!(model, pars, conf_train;
n_samples = 3, sampler = LHSampler(),
ws = [1.0, 1.0], int = true)

# Optimal IAP
@save_var res_path iap.β
@save_var res_path iap.β0
@save_var res_path iap.basis
@save_dataframe res_path res
res

# Plot error vs time
err_time = plot_err_time(ho)
err_time = plot_err_time(res)
@save_fig res_path err_time
DisplayAs.PNG(err_time)


33 changes: 24 additions & 9 deletions examples/Opt-ACE-aHfO2/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,18 @@ function get_results(ho)
end

# Plot fitting error vs force time (Pareto front)
function plot_err_time(ho)
error = [r[2][:error] for r in ho.results]
times = [r[2][:time_us] for r in ho.results]
function plot_err_time(res)
error = res[!, :error]
times = res[!, :time_us]
scatter(times,
error,
label = "",
xaxis = "Time per force per atom | µs",
yaxis = "we MSE(E, E') + wf MSE(F, F')")
end

function hyper_loss(p)

function loss(p)
err, e_mae, f_mae, time_us = p[1], p[2], p[3], p[4]
e_mae_max, f_mae_max = 0.05, 0.05
if e_mae < e_mae_max && f_mae < f_mae_max
Expand All @@ -48,14 +49,27 @@ function hyper_loss(p)
return loss
end

# Return the distinct values produced by `atomic_symbol` over the systems of all
# elements of `confs`, with duplicates removed both within and across elements.
function get_species(confs)
    # One de-duplicated symbol collection per element of `confs`.
    symbols_per_conf = unique.(atomic_symbol.(get_system.(confs)))
    # Concatenate the per-element collections into a single vector.
    all_symbols = vcat(symbols_per_conf...)
    # Drop the duplicates that remain across different elements.
    return unique(all_symbols)
end

# Placeholder definition of `create_ho`: ignores its argument and builds
# `Hyperoptimizer(1)`. `hyperlearn!` evaluates a new one-argument `create_ho`
# definition at run time and calls it through `Base.invokelatest`, so this method
# only ensures the name exists beforehand.
function create_ho(x)
    return Hyperoptimizer(1)
end

# hyperlearn!
function hyperlearn!(n_samples, model, pars, conf_train;
ws = [1.0, 1.0], int = true, loss = hyper_loss)
function hyperlearn!(model, pars, conf_train;
n_samples = 5, sampler = RandomSampler(), loss = loss,
ws = [1.0, 1.0], int = true)

s = "ho = Hyperoptimizer($n_samples," * join("$k = $v, " for (k, v) in pars) * ")"
s = "create_ho(sampler) = Hyperoptimizer($n_samples, sampler, " *
join("$k = $v, " for (k, v) in pars) * ")"
eval(Meta.parse(s))
ho = Base.invokelatest(create_ho, sampler)
if (ho.sampler isa LHSampler) || (ho.sampler isa CLHSampler)
Hyperopt.init!(ho.sampler, ho)
end
species = get_species(conf_train)
for (i, state...) in ho
basis = model(; state...)
basis = model(; species = species, state...)
iap = LBasisPotential(basis)
## Compute energy and force descriptors
e_descr_new = compute_local_descriptors(conf_train, iap.basis, pbar = false)
Expand All @@ -71,7 +85,8 @@ function hyperlearn!(n_samples, model, pars, conf_train;
f_mae, f_rmse, f_rsq = calc_metrics(f_pred, f)
time_us = estimate_time(conf_train, iap) * 10^6
err = ws[1] * e_rmse^2 + ws[2] * f_rmse^2
metrics = OrderedDict( :e_mae => e_mae,
metrics = OrderedDict( :error => err,
:e_mae => e_mae,
:e_rmse => e_rmse,
:e_rsq => e_rsq,
:f_mae => f_mae,
Expand Down

0 comments on commit b4420fc

Please sign in to comment.