From e76a16389ba465232d346dc05999c38299bf89f4 Mon Sep 17 00:00:00 2001 From: Spencer Wyant <17836774+swyant@users.noreply.github.com> Date: Thu, 14 Nov 2024 15:07:26 -0500 Subject: [PATCH 1/2] adding ProgressBar to ooc_learn --- src/Learning/linear-learn.jl | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/Learning/linear-learn.jl b/src/Learning/linear-learn.jl index 081658b..5e2252a 100644 --- a/src/Learning/linear-learn.jl +++ b/src/Learning/linear-learn.jl @@ -306,7 +306,8 @@ function ooc_learn!( λ::Union{Real,Nothing} = 0.01, reg_style::Symbol = :default, AtWA = nothing, - AtWb = nothing + AtWb = nothing, + pbar = true ) basis_size = length(lb.basis) @@ -317,9 +318,13 @@ function ooc_learn!( W = zeros(1,1) - configs = get_system.(ds_train) + if pbar + iter = ProgressBar(ds_train) + else + iter = ds_train + end - for config in ds_train + for config in iter ref_energy = get_values(get_energy(config)) ref_forces = reduce(vcat,get_values(get_forces(config))) From 378f542761009269470daa39bcb1f5a7dfffc41b Mon Sep 17 00:00:00 2001 From: Spencer Wyant <17836774+swyant@users.noreply.github.com> Date: Thu, 14 Nov 2024 15:48:03 -0500 Subject: [PATCH 2/2] enable normalizing energy weights --- src/Learning/linear-learn.jl | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/src/Learning/linear-learn.jl b/src/Learning/linear-learn.jl index 5e2252a..bdfce92 100644 --- a/src/Learning/linear-learn.jl +++ b/src/Learning/linear-learn.jl @@ -307,7 +307,8 @@ function ooc_learn!( reg_style::Symbol = :default, AtWA = nothing, AtWb = nothing, - pbar = true + pbar = true, + eweight_normalized = :squared # :squared, :standard or nothing ) basis_size = length(lb.basis) @@ -329,14 +330,26 @@ function ooc_learn!( ref_forces = reduce(vcat,get_values(get_forces(config))) sys = get_system(config) + natoms = length(sys) global_descrs = reshape(sum(compute_local_descriptors(sys,lb.basis)),:,1)' force_descrs = stack(reduce(vcat,compute_force_descriptors(sys,lb.basis)))' A = [global_descrs; force_descrs] b = [ref_energy; ref_forces] if size(W)[1] != size(A)[1] - W = Diagonal( [ws[1]*ones(length(ref_energy)); - ws[2]*ones(length(ref_forces))]) + + if isnothing(eweight_normalized) + we_norm = 1.0 + elseif eweight_normalized == :standard + we_norm = 1/natoms + elseif eweight_normalized == :squared + we_norm = 1/natoms^2 + else + error("eweight_normalized can only be nothing, :standard, or :squared") + end + + W = Diagonal( [we_norm*ws[1]; + ws[2]*ones(length(ref_forces))] ) end AtWA .+= A'*W*A