Skip to content

Commit

Permalink
Improvements in loss function in hyperparamter optiimization.
Browse files Browse the repository at this point in the history
  • Loading branch information
emmanuellujan committed Jun 27, 2024
1 parent 0883a8d commit b85d784
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 25 deletions.
25 changes: 22 additions & 3 deletions examples/Opt-ACE-aHfO2/fit-opt-ace-ahfo2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,26 +32,44 @@ conf_train, conf_test = split(ds, n_train, n_test)

# ## c. Hyper-parameter optimization.

# Define the model and hyper-parameter value ranges to be optimized.
# Define loss function. Here, we minimize error and time.
function loss(metrics)
e_mae_max = 0.05
f_mae_max = 0.05
err = metrics[:error] # weighted error: w_e * e_mae + w_f * f_mae
e_mae = metrics[:e_mae]
f_mae = metrics[:f_mae]
time_us = metrics[:time_us]
if e_mae < e_mae_max && f_mae < f_mae_max
loss = time_us
else
loss = time_us + err * 10^3
end
return loss
end

# Define model and hyper-parameter value ranges to be optimized.
model = ACE
pars = OrderedDict( :body_order => [2, 3, 4],
:polynomial_degree => [3, 4, 5],
:rcutoff => [4.5, 5.0, 5.5],
:wL => [0.5, 1.0, 1.5],
:csp => [0.5, 1.0, 1.5],
:r0 => [0.5, 1.0, 1.5]);

# Use random sampling to find the optimal hyper-parameters.
iap, res = hyperlearn!(model, pars, conf_train;
n_samples = 10, sampler = RandomSampler(),
ws = [1.0, 1.0], int = true)

# Save and show results.
@save_var res_path iap.β
@save_var res_path iap.β0
@save_var res_path iap.basis
@save_dataframe res_path res
res

# Plot error vs time
# Plot error vs time.
err_time = plot_err_time(res)
@save_fig res_path err_time
DisplayAs.PNG(err_time)
Expand All @@ -62,13 +80,14 @@ iap, res = hyperlearn!(model, pars, conf_train;
n_samples = 3, sampler = LHSampler(),
ws = [1.0, 1.0], int = true)

# Save and show results.
@save_var res_path iap.β
@save_var res_path iap.β0
@save_var res_path iap.basis
@save_dataframe res_path res
res

# Plot error vs time
# Plot error vs time.
err_time = plot_err_time(res)
@save_fig res_path err_time
DisplayAs.PNG(err_time)
Expand Down
23 changes: 1 addition & 22 deletions examples/Opt-ACE-aHfO2/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,6 @@ function get_results(ho)
end

# Plot fitting error vs force time (Pareto front)
function plot_err_time(res)
error = res[!, :error]
times = res[!, :time_us]
scatter(times,
error,
label = "",
xaxis = "Time per force per atom | µs",
yaxis = "we MAE(E, E') + wf MAE(F, F')")
end

function plot_err_time(res)
e_mae = res[!, :e_mae]
f_mae = res[!, :f_mae]
Expand Down Expand Up @@ -70,17 +60,6 @@ function plot_err_time(res)
end


function loss(p)
err, e_mae, f_mae, time_us = p[1], p[2], p[3], p[4]
e_mae_max, f_mae_max = 0.05, 0.05
if e_mae < e_mae_max && f_mae < f_mae_max
loss = time_us
else
loss = time_us + err * 10^3
end
return loss
end

function get_species(confs)
return unique(vcat(unique.(atomic_symbol.(get_system.(confs)))...))
end
Expand Down Expand Up @@ -126,7 +105,7 @@ function hyperlearn!(model, pars, conf_train;
:f_rsq => f_rsq,
:time_us => time_us)
## Compute multi-objetive loss based on error and time
l = loss([err, e_mae, f_mae, time_us])
l = loss(metrics)
## Print results
print("E_MAE:$(round(e_mae; digits=3)), ")
print("F_MAE:$(round(f_mae; digits=3)), ")
Expand Down

0 comments on commit b85d784

Please sign in to comment.