From 4faa3fbdd85cb5d81f3b50b438d818f23d7ec9c2 Mon Sep 17 00:00:00 2001
From: Spencer Wyant <17836774+swyant@users.noreply.github.com>
Date: Wed, 16 Oct 2024 15:55:10 -0400
Subject: [PATCH 1/3] initial ooc linear learn fn

---
 src/Learning/linear-learn.jl | 89 +++++++++++++++++++++++++++++++++++-
 1 file changed, 88 insertions(+), 1 deletion(-)

diff --git a/src/Learning/linear-learn.jl b/src/Learning/linear-learn.jl
index 341fb2e..872ce72 100644
--- a/src/Learning/linear-learn.jl
+++ b/src/Learning/linear-learn.jl
@@ -263,7 +263,7 @@ function learn!(
     else
         lp.β .= βs
     end
-    
+
 end
 
 
@@ -282,4 +282,91 @@ function learn!(
     return learn!(lp, ws, int)
 end
 
 
+function assemble_matrices(lp, ws)
+    @views B_train = reduce(hcat, lp.B)'
+    @views dB_train = reduce(hcat, lp.dB)'
+    @views e_train = lp.e
+    @views f_train = reduce(vcat, lp.f)
+
+    @views A = [B_train; dB_train]
+    @views b = [e_train; f_train]
+
+    W = Diagonal([ws[1] * ones(length(e_train));
+                  ws[2] * ones(length(f_train))])
+
+    A, W, b
+end
+
+function ooc_learn!(
+    lb::InteratomicPotentials.LinearBasisPotential,
+    ds_train::PotentialLearning.DataSet;
+    ws = [30.0,1.0],
+    symmetrize::Bool = true,
+    lambda::Union{Real,Nothing} = 0.01,
+    reg_style::Symbol = :default,
+    AtWA = nothing,
+    AtWb = nothing
+)
+
+    basis_size = length(lb.basis)
+
+    if isnothing(AtWA) || isnothing(AtWb)
+        AtWA = zeros(basis_size,basis_size)
+        AtWb = zeros(basis_size)
+
+        W = zeros(1,1)
+
+        configs = get_system.(ds_train)
+
+        for config in ds_train
+            ref_energy = get_values(get_energy(config))
+            ref_forces = reduce(vcat,get_values(get_forces(config)))
+
+            sys = get_system(config)
+            global_descrs = reshape(sum(compute_local_descriptors(sys,lb.basis)),:,1)'
+            force_descrs = stack(reduce(vcat,compute_force_descriptors(sys,lb.basis)))'
+
+            A = [global_descrs; force_descrs]
+            b = [ref_energy; ref_forces]
+            if size(W)[1] != size(A)[1]
+                W = Diagonal( [ws[1]*ones(length(ref_energy));
+                               ws[2]*ones(length(ref_forces))])
+            end
+
+            AtWA .+= A'*W*A
+            AtWb .+= A'*W*b
+        end
+    else
+        AtWA = deepcopy(AtWA)
+        AtWb = deepcopy(AtWb)
+    end
+
+    if symmetrize
+        AtWA = Symmetric(AtWA)
+    end
+
+    if !isnothing(lambda)
+        if reg_style == :default
+            reg_matrix = lambda*Diagonal(ones(size(AtWA)[1]))
+            AtWA += reg_matrix
+        end
+
+        if reg_style == :scale_thresh || reg_style == :scale
+            for i in 1:size(AtWA,1)
+                reg_elem = AtWA[i,i]*(1+lambda)
+                if reg_style == :scale_thresh
+                    reg_elem = max(reg_elem,lambda)
+                end#
+                AtWA[i,i] = reg_elem
+            end
+        end
+
+    end
+    β = AtWA \ AtWb
+    println("condition number of AtWA: $(cond(AtWA))")
+
+    lb.β .= β
+
+    AtWA, AtWb
+end

From 67a5aaa9f3f063d1932665600030396270aa91f9 Mon Sep 17 00:00:00 2001
From: Spencer Wyant <17836774+swyant@users.noreply.github.com>
Date: Fri, 18 Oct 2024 11:15:26 -0400
Subject: [PATCH 2/3] adding regularization to standard wls routine

---
 src/Learning/linear-learn.jl | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/src/Learning/linear-learn.jl b/src/Learning/linear-learn.jl
index 872ce72..081658b 100644
--- a/src/Learning/linear-learn.jl
+++ b/src/Learning/linear-learn.jl
@@ -226,7 +226,8 @@ Fit energies and forces using weighted least squares.
 function learn!(
     lp::CovariateLinearProblem,
     ws::Vector,
-    int::Bool
+    int::Bool;
+    λ::Real=0.0
 )
     @views B_train = reduce(hcat, lp.B)'
     @views dB_train = reduce(hcat, lp.dB)'
@@ -249,11 +250,11 @@ function learn!(
 
     βs = Vector{Float64}()
     try
-        βs = (A'*Q*A) \ (A'*Q*b)
+        βs = (A'*Q*A + λ*I) \ (A'*Q*b)
     catch e
         println(e)
         println("Linear system will be solved using pinv.")
-        βs = pinv(A'*Q*A)*(A'*Q*b)
+        βs = pinv(A'*Q*A + λ*I)*(A'*Q*b)
     end
 
     # Update lp.
@@ -302,7 +303,7 @@ function ooc_learn!(
     ds_train::PotentialLearning.DataSet;
     ws = [30.0,1.0],
     symmetrize::Bool = true,
-    lambda::Union{Real,Nothing} = 0.01,
+    λ::Union{Real,Nothing} = 0.01,
     reg_style::Symbol = :default,
     AtWA = nothing,
     AtWb = nothing
@@ -345,17 +346,17 @@ function ooc_learn!(
         AtWA = Symmetric(AtWA)
     end
 
-    if !isnothing(lambda)
+    if !isnothing(λ)
         if reg_style == :default
-            reg_matrix = lambda*Diagonal(ones(size(AtWA)[1]))
+            reg_matrix = λ*Diagonal(ones(size(AtWA)[1]))
             AtWA += reg_matrix
         end
 
         if reg_style == :scale_thresh || reg_style == :scale
             for i in 1:size(AtWA,1)
-                reg_elem = AtWA[i,i]*(1+lambda)
+                reg_elem = AtWA[i,i]*(1+λ)
                 if reg_style == :scale_thresh
-                    reg_elem = max(reg_elem,lambda)
+                    reg_elem = max(reg_elem,λ)
                 end#
                 AtWA[i,i] = reg_elem
             end

From 60a20c50f745ad2612c826daafe11a89a354439b Mon Sep 17 00:00:00 2001
From: Spencer Wyant <17836774+swyant@users.noreply.github.com>
Date: Fri, 18 Oct 2024 11:17:45 -0400
Subject: [PATCH 3/3] version bump

---
 Project.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index fb4c8fd..d334c77 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "PotentialLearning"
 uuid = "82b0a93c-c2e3-44bc-a418-f0f89b0ae5c2"
 authors = ["CESMIX Team"]
-version = "0.2.6"
+version = "0.2.7"
 
 [deps]
 AtomsBase = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a"
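
Usage note (illustrative sketch, not part of the patch series): with all three patches applied, the new out-of-core path can be driven roughly as below. Only the signatures introduced above are taken from the patches — ooc_learn!(lb, ds_train; ws, symmetrize, λ, reg_style, AtWA, AtWb) returning the accumulated (AtWA, AtWb) blocks, and the optional λ keyword on learn!(lp, ws, int; λ). The existence of a prebuilt `lb` (an InteratomicPotentials.LinearBasisPotential) and `ds_train` (a PotentialLearning.DataSet whose configurations carry reference energies and forces) is assumed and their construction is not shown; qualify the call as PotentialLearning.ooc_learn! if the function is not exported.

using PotentialLearning, InteratomicPotentials

# Assumed inputs: `lb` is a linear basis potential (e.g. wrapping an ACE basis) and
# `ds_train` is a DataSet with reference energies and forces for each configuration.

# Single pass over the dataset: descriptors are recomputed per configuration and only the
# basis_size x basis_size normal-equation blocks A'WA and A'Wb are accumulated, so the full
# design matrix never has to fit in memory. Energy rows are weighted by 30 and force rows
# by 1; reg_style = :scale multiplies each diagonal entry of A'WA by (1 + λ) before solving.
AtWA, AtWb = ooc_learn!(lb, ds_train; ws = [30.0, 1.0], λ = 1e-8, reg_style = :scale)

# Re-solve with a different regularization without another pass over the data by feeding
# the accumulated blocks back in; lb.β is overwritten with the new coefficients.
ooc_learn!(lb, ds_train; λ = 1e-6, reg_style = :default, AtWA = AtWA, AtWb = AtWb)

# Patch 2 threads the same kind of ridge term through the in-memory weighted least-squares
# routine: learn!(lp, ws, int; λ = 1e-8) now solves (A'QA + λI) β = A'Qb, with λ = 0 by default.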