From 378f542761009269470daa39bcb1f5a7dfffc41b Mon Sep 17 00:00:00 2001 From: Spencer Wyant <17836774+swyant@users.noreply.github.com> Date: Thu, 14 Nov 2024 15:48:03 -0500 Subject: [PATCH] enable normalizing energy weights --- src/Learning/linear-learn.jl | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/src/Learning/linear-learn.jl b/src/Learning/linear-learn.jl index 5e2252a..bdfce92 100644 --- a/src/Learning/linear-learn.jl +++ b/src/Learning/linear-learn.jl @@ -307,7 +307,8 @@ function ooc_learn!( reg_style::Symbol = :default, AtWA = nothing, AtWb = nothing, - pbar = true + pbar = true, + eweight_normalized = :squared # :squared, :standard or nothing ) basis_size = length(lb.basis) @@ -329,14 +330,26 @@ function ooc_learn!( ref_forces = reduce(vcat,get_values(get_forces(config))) sys = get_system(config) + natoms = length(sys) global_descrs = reshape(sum(compute_local_descriptors(sys,lb.basis)),:,1)' force_descrs = stack(reduce(vcat,compute_force_descriptors(sys,lb.basis)))' A = [global_descrs; force_descrs] b = [ref_energy; ref_forces] if size(W)[1] != size(A)[1] - W = Diagonal( [ws[1]*ones(length(ref_energy)); - ws[2]*ones(length(ref_forces))]) + + if isnothing(eweight_normalized) + we_norm = 1.0 + elseif eweight_normalized == :standard + we_norm = 1/natoms + elseif eweight_normalized == :squared + we_norm = 1/natoms^2 + else + error("eweight_normalized can only be nothing, :standard, or :squared") + end + + W = Diagonal( [we_norm*ws[1]; + ws[2]*ones(length(ref_forces))] ) end AtWA .+= A'*W*A