Skip to content

Commit

Permalink
enable normalizing energy weights
Browse files Browse the repository at this point in the history
  • Loading branch information
swyant committed Nov 14, 2024
1 parent e76a163 commit 378f542
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions src/Learning/linear-learn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,8 @@ function ooc_learn!(
reg_style::Symbol = :default,
AtWA = nothing,
AtWb = nothing,
pbar = true
pbar = true,
eweight_normalized = :squared # :squared, :standard or nothing
)

basis_size = length(lb.basis)
Expand All @@ -329,14 +330,26 @@ function ooc_learn!(
ref_forces = reduce(vcat,get_values(get_forces(config)))

sys = get_system(config)
natoms = length(sys)
global_descrs = reshape(sum(compute_local_descriptors(sys,lb.basis)),:,1)'
force_descrs = stack(reduce(vcat,compute_force_descriptors(sys,lb.basis)))'

A = [global_descrs; force_descrs]
b = [ref_energy; ref_forces]
if size(W)[1] != size(A)[1]
W = Diagonal( [ws[1]*ones(length(ref_energy));
ws[2]*ones(length(ref_forces))])

if isnothing(eweight_normalized)
we_norm = 1.0
elseif eweight_normalized == :standard
we_norm = 1/natoms
elseif eweight_normalized == :squared
we_norm = 1/natoms^2
else
error("eweight_normalized can only be nothing, :standard, or :squared")
end

W = Diagonal( [we_norm*ws[1];
ws[2]*ones(length(ref_forces))] )
end

AtWA .+= A'*W*A
Expand Down

0 comments on commit 378f542

Please sign in to comment.