diff --git a/src/Learning/linear-learn.jl b/src/Learning/linear-learn.jl index 081658b..bdfce92 100644 --- a/src/Learning/linear-learn.jl +++ b/src/Learning/linear-learn.jl @@ -306,7 +306,9 @@ function ooc_learn!( λ::Union{Real,Nothing} = 0.01, reg_style::Symbol = :default, AtWA = nothing, - AtWb = nothing + AtWb = nothing, + pbar = true, + eweight_normalized = :squared # :squared, :standard or nothing ) basis_size = length(lb.basis) @@ -317,21 +319,37 @@ function ooc_learn!( W = zeros(1,1) - configs = get_system.(ds_train) + if pbar + iter = ProgressBar(ds_train) + else + iter = ds_train + end - for config in ds_train + for config in iter ref_energy = get_values(get_energy(config)) ref_forces = reduce(vcat,get_values(get_forces(config))) sys = get_system(config) + natoms = length(sys) global_descrs = reshape(sum(compute_local_descriptors(sys,lb.basis)),:,1)' force_descrs = stack(reduce(vcat,compute_force_descriptors(sys,lb.basis)))' A = [global_descrs; force_descrs] b = [ref_energy; ref_forces] if size(W)[1] != size(A)[1] - W = Diagonal( [ws[1]*ones(length(ref_energy)); - ws[2]*ones(length(ref_forces))]) + + if isnothing(eweight_normalized) + we_norm = 1.0 + elseif eweight_normalized == :standard + we_norm = 1/natoms + elseif eweight_normalized == :squared + we_norm = 1/natoms^2 + else + error("eweight_normalized can only be nothing, :standard, or :squared") + end + + W = Diagonal( [we_norm*ws[1]; + ws[2]*ones(length(ref_forces))] ) end AtWA .+= A'*W*A