Skip to content

Commit

Permalink
Merge b6db288 into a956d94
Browse files Browse the repository at this point in the history
  • Loading branch information
emmanuellujan authored Oct 6, 2023
2 parents a956d94 + b6db288 commit c114f59
Show file tree
Hide file tree
Showing 12 changed files with 52,564 additions and 13,270 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.2.0"

[deps]
AtomsBase = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a"
Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5"
Determinantal = "2673d5e8-682c-11e9-2dfd-471b09c6c819"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Expand Down
51,987 changes: 51,987 additions & 0 deletions examples/DFT-subsampling/DPP_training/Hf/HfO2FPOD_training_analysis.pod

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
**************** Begin of Error Analysis for the Training Data Set ****************
---------------------------------------------------------------------------------------------------------------------------
File | # configs | # atoms | MAE energy | RMSE energy | MAE force | RMSE force
---------------------------------------------------------------------------------------------------------------------------
Hf128_MC_rattled_mp100_form_sorted.xyz 76 9728 0.005097 0.005721 0.128172 0.164395
Hf128_MC_rattled_mp103_form_sorted.xyz 12 1536 0.012707 0.017490 0.159344 0.205455
Hf128_MC_rattled_random_form_sorted.xyz 124 15872 0.023796 0.025652 0.295837 0.389306
Hf2_gas_form_sorted.xyz 63 126 0.063478 0.080521 0.087106 0.217110
Hf2_mp103_EOS_1D_form_sorted.xyz 91 182 0.043395 0.052563 0.000002 0.000005
Hf2_mp103_EOS_3D_form_sorted.xyz 4437 8874 0.028522 0.041463 0.064457 0.153331
Hf2_mp103_EOS_6D_form_sorted.xyz 8077 16154 0.031891 0.049949 0.094924 0.193149
HfO2_figshare_form_sorted.xyz 513 50750 0.004937 0.006500 0.124342 0.167539
HfO2_mp352_EOS_1D_form_sorted.xyz 112 1344 0.011358 0.012850 0.039550 0.053849
HfO2_mp550893_EOS_1D_form_sorted.xyz 65 195 0.019022 0.026343 0.000090 0.000508
HfO2_mp550893_EOS_6D_form_sorted.xyz 12765 38295 0.024245 0.034777 0.059607 0.130631
HfO_EOS_6D_form_sorted.xyz 7237 14474 0.022218 0.031917 0.001420 0.018106
HfO_gas_form_sorted.xyz 129 258 0.007752 0.024816 0.028570 0.065918
Hf_mp100_EOS_1D_form_sorted.xyz 37 37 0.057238 0.076505 0.000000 0.000000
Hf_mp100_EOS_3D_form_sorted.xyz 1691 1691 0.032657 0.049885 0.000000 0.000000
Hf_mp100_EOS_6D_form_sorted.xyz 2996 2996 0.023730 0.034286 0.000000 0.000000
Hf_mp100_primitive_EOS_1D_form_sorted.xyz 37 37 0.052678 0.067729 0.000000 0.000000
Hf_mp100_primitive_EOS_3D_form_sorted.xyz 1665 1665 0.026263 0.037030 0.000000 0.000000
O2_gas_form_sorted.xyz 204 408 0.063950 0.083066 0.075771 0.190771
O2_mp607540_EOS_6D_form_sorted.xyz 8253 16506 0.075737 0.111845 0.353885 0.726441
O_EOS_6D_form_sorted.xyz 3361 3361 0.070968 0.094037 0.000000 0.000000
---------------------------------------------------------------------------------------------------------------------------
All files 51945 184489 0.037013 0.061469 0.123854 0.279027
---------------------------------------------------------------------------------------------------------------------------
**************** End of Error Analysis for the Training Data Set ****************
86 changes: 86 additions & 0 deletions examples/DFT-subsampling/subsampling_dpp.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
push!(Base.LOAD_PATH, "../../")

using PotentialLearning
using LinearAlgebra, Random, Statistics, StatsBase, Distributions
using AtomsBase, Unitful, UnitfulAtomic
using InteratomicPotentials, InteratomicBasisPotentials
using Determinantal
using CairoMakie
using InvertedIndices
using CSV
using JLD
using DataFrames

include("subsampling_utils.jl")


# load dataset -------------------------------------------------------------------
# `elname` labels the element; it also names the input/output folders and the
# result files written by this example.
elname = "Hf"
elspec = [:Hf]
# Both paths keep their trailing slash because file paths are built with `*`.
inpath = "./DFT_data/$elname/"
outpath = "./DPP_training/$elname/"

# read all data
# One dataset is loaded per `.xyz` file found in `inpath`; they are then merged
# into a single dataset with `concat_dataset`.
file_arr = readext(inpath, "xyz")
nfile = length(file_arr)
# The length-unit literal was empty (`u""`); `u"Å"` is the unit used by the
# single-file example further down in this script.
confs_arr = [load_data(inpath*file, ExtXYZ(u"eV", u"Å")) for file in file_arr]
confs = concat_dataset(confs_arr)

# id of configurations per file
# `confs_id[k]` holds the 1-based positions, inside the merged `confs`, of the
# configurations that came from `file_arr[k]`.
#
# The ranges are derived from cumulative counts instead of updating a global counter
# inside a top-level `for` loop: when this file is run as a script, an assignment such
# as `n += ...` inside the loop makes `n` local to the loop body, so reading it raises
# `UndefVarError`.
nconfs_per_file = Int[length(c) for c in confs_arr]
last_id = cumsum(nconfs_per_file)
confs_id = Vector{Int64}[collect((last_id[k] - nconfs_per_file[k] + 1):last_id[k])
                         for k in 1:nfile]
# Total number of configurations; kept under its original name for compatibility.
n = sum(nconfs_per_file)

# read single file
# datafile = "Hf_mp100_EOS_1D_form_sorted.xyz"
# confs = load_data(inpath*datafile, ExtXYZ(u"eV", u"Å"))


# define ACE basis ---------------------------------------------------------------
# `nbody` and `deg` stay as named globals: the output file names produced by this
# example interpolate them.
nbody = 4    # body order passed to the basis
deg = 5      # polynomial degree passed to the basis
ace = ACE(;
    species = elspec,
    body_order = nbody,
    polynomial_degree = deg,
    wL = 1.0,        # Defaults, See ACE.jl documentation
    csp = 1.0,       # Defaults, See ACE.jl documentation
    r0 = 1.0,        # noted by the author as the minimum distance between atoms
    rcutoff = 10.0,
)


# Update dataset by adding energy (local) descriptors ----------------------------
println("Computing local descriptors")
@time e_descr = compute_local_descriptors(confs, ace)
@time f_descr = compute_force_descriptors(confs, ace)

# Cache the descriptors on disk. The file names are derived from `elname` (instead of
# a hard-coded "Hf") so they stay consistent with the other outputs of this example
# when a different element is selected; for elname == "Hf" the names are unchanged.
mkpath(outpath)  # make sure the output directory exists before writing into it
JLD.save(outpath*"$(elname)_energy_descriptors.jld", "e_descr", e_descr)
JLD.save(outpath*"$(elname)_force_descriptors.jld", "f_descr", f_descr)

# Attach both descriptor sets to the configurations and wrap them as one DataSet.
ds = DataSet(confs .+ e_descr .+ f_descr)
ndata = length(ds)

# compute cross validation error from training -----------------------------------

# DPP subset sizes to evaluate; each one triggers a full 5-fold cross validation.
batch_size = [2000, 1000, 500, 200, 100]
# batch_size = [80, 40, 20]

# Results keyed by DPP subset size.
sel_ind = Dict{Int64, Vector}()
cond_num = Dict{Int64, Vector}()

for nsel in batch_size
    println("=============== starting batch size $nsel ===============")
    fold_indices, fold_condnums = cross_validation_training(ds; ndiv=5, dpp_batch=nsel)
    sel_ind[nsel] = fold_indices
    cond_num[nsel] = fold_condnums
end

# Persist the selected indices and condition numbers for all subset sizes.
JLD.save(
    outpath*"$(elname)_ACE-$(nbody)-$(deg)_DPP_indices_and_condnum.jld",
    "ind", sel_ind,
    "condnum", cond_num,
)








134 changes: 134 additions & 0 deletions examples/DFT-subsampling/subsampling_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""
    readext(path, ext)

Return the names of the entries of directory `path` whose name ends in `.ext`.

# Arguments
- `path`: directory to list (passed to `readdir`).
- `ext`: extension without the leading dot, e.g. `"xyz"`. Multi-part extensions such
  as `"tar.gz"` are supported.

# Returns
A `Vector{String}` of matching entry names (not full paths), in `readdir` order.

An entry whose whole name equals `ext` (no dot, e.g. a file called `xyz`) is not
treated as a match.
"""
function readext(path::AbstractString, ext::AbstractString)
    suffix = "." * ext
    return filter(name -> endswith(name, suffix), readdir(path))
end


"""
    concat_dataset(confs::Vector{DataSet})

Merge several `DataSet`s into one: the elements of each input dataset are collected
by index, the per-dataset lists are concatenated in input order, and the result is
wrapped in a new `DataSet`.
"""
function concat_dataset(confs::Vector{DataSet})
    # One plain vector of elements per input dataset.
    per_dataset = map(confs) do d
        [d[j] for j = 1:length(d)]
    end
    return DataSet(reduce(vcat, per_dataset))
end


"""
    train_potential(ds_train, ace, dpp_batch)

Fit a linear basis potential built from `ace` on a subset of `ds_train` drawn with a
k-DPP of size `dpp_batch`.

# Returns
A tuple `(lb, cond_num, dpp_inds)`: the potential passed to `learn!`, the condition
number reported by `compute_cond_num` for the object `learn!` returned, and the
indices (into `ds_train`) of the selected subset.
"""
function train_potential(ds_train::DataSet, ace::ACE, dpp_batch::Int64)
    lb = LBasisPotentialExt(ace)

    # learn with DPP: sample `dpp_batch` configurations from the training set
    sampler = kDPP(ds_train, GlobalMean(), DotProduct(); batch_size = dpp_batch)
    subset = get_random_subset(sampler)

    # NOTE(review): `[100, 1]` and `false` are forwarded unchanged to `learn!`;
    # presumably fitting weights and an intercept flag -- confirm against its docs.
    problem = learn!(lb, ds_train[subset], [100, 1], false)

    return lb, compute_cond_num(problem), subset
end


"""
    compute_cond_num(lp::PotentialLearning.UnivariateLinearProblem)

Return the condition number of the matrix whose rows are the entries of `lp.iv_data`.
"""
function compute_cond_num(lp::PotentialLearning.UnivariateLinearProblem)
    # Entries of `iv_data` become columns here; the adjoint turns them into rows.
    columns = reduce(hcat, lp.iv_data)
    return cond(columns')
end


"""
    compute_cond_num(lp::PotentialLearning.CovariateLinearProblem)

Return the condition number of the matrix obtained by stacking, row-wise, the entries
of `lp.B` on top of the entries of `lp.dB`.
"""
function compute_cond_num(lp::PotentialLearning.CovariateLinearProblem)
    rows_B = reduce(hcat, lp.B)'
    rows_dB = reduce(hcat, lp.dB)'
    return cond(vcat(rows_B, rows_dB))
end


"""
    cross_validation_training(ds; ndiv=10, dpp_batch=Int(floor(2*length(ds) / ndiv)))

Run an `ndiv`-fold cross validation of DPP-based training on `ds` and write the
resulting error tables to CSV.

The configuration indices are randomly permuted and cut into `ndiv` disjoint folds of
equal size (`length(ds) ÷ ndiv`; leftover indices are not used). For each fold `i`, a
potential is trained with `train_potential` on the union of the other folds, and
energies / force magnitudes are then predicted for *every* configuration in `ds`.

This function reads the following globals defined by the calling script: `file_arr`,
`nfile`, `confs_arr`, `confs_id`, `ace`, `outpath`, `elname`, `nbody`, `deg`.

# Keywords
- `ndiv`: number of folds.
- `dpp_batch`: size of the DPP subset used for training in each fold.

# Returns
- `sel_ind`: for each fold, the indices (into `ds`) of the configurations selected
  by the DPP.
- `cond_num`: for each fold, the condition number returned by `train_potential`.

# Side effects
Writes two CSV files into `outpath`: per-configuration errors (`..._train_full_...`)
and per-file error statistics (`..._train_metadata_...`).
"""
function cross_validation_training(ds; ndiv=10,
                                   dpp_batch=Int(floor(2*length(ds) / ndiv))
                                   )

    # retrieve DFT data
    ndata = length(ds)
    energies = get_all_energies(ds)
    # per-configuration sum of the norms of the force entries (magnitude)
    forces = [sum(norm.(get_values(get_forces(ds[c])))) for c in 1:ndata]

    # init arrays
    cond_num = zeros(ndiv)
    e_err = Matrix{Float64}(undef, (ndata, ndiv))
    f_err = Matrix{Float64}(undef, (ndata, ndiv))
    e_mae = Dict(f => zeros(ndiv) for f in file_arr)
    e_rmse = Dict(f => zeros(ndiv) for f in file_arr)
    f_mae = Dict(f => zeros(ndiv) for f in file_arr)
    f_rmse = Dict(f => zeros(ndiv) for f in file_arr)
    sel_ind = Vector{Vector{Int64}}(undef, ndiv)

    # make random divisions of data
    # A permutation (rather than `rand(1:ndata, ndata)`, which samples with
    # replacement) guarantees that the folds are disjoint and free of duplicates.
    ncut = fld(ndata, ndiv)
    ind_all = randperm(ndata)
    ind = [ind_all[(k*ncut+1):((k+1)*ncut)] for k = 0:(ndiv-1)]

    # iterate over divisions
    for i = 1:ndiv
        println("batch $i")
        # split train/test sets: train on every fold except the i-th
        train_ind = reduce(vcat, ind[Not(i)])
        ds_train = ds[train_ind]

        # train using dpp
        lb, cond_num[i], dpp_ind = train_potential(ds_train, ace, dpp_batch)
        # `dpp_ind` indexes `ds_train`; map it back to positions in `ds`
        sel_ind[i] = train_ind[dpp_ind]

        # get predicted energies and force magnitudes for the whole dataset
        e_pred = get_all_energies(ds, lb)
        f_pred = get_all_forces_mag(ds, lb)
        e_err[:, i] = energies - e_pred
        f_err[:, i] = forces - f_pred

        # per-file error metrics for this fold
        for j = 1:nfile
            file = file_arr[j]
            ids = confs_id[j]
            e_mae[file][i], e_rmse[file][i], _ = calc_metrics(energies[ids], e_pred[ids])
            f_mae[file][i], f_rmse[file][i], _ = calc_metrics(forces[ids], f_pred[ids])
        end
    end

    # populate DataFrame
    df_conf = DataFrame(
        "config" => 1:ndata,
        "file" => reduce(vcat, [[file_arr[j] for i = 1:length(confs_arr[j])] for j = 1:nfile]),
        "DFT energy" => energies,
        "energy err mean" => mean(e_err, dims=2)[:],  # mean error from k-fold CV
        "energy err std" => std(e_err, dims=2)[:],    # std of error
        "DFT force" => forces,
        "force err mean" => mean(f_err, dims=2)[:],   # mean error from k-fold CV
        "force err std" => std(f_err, dims=2)[:],     # std of error
    )

    # Statistics are looked up in `file_arr` order so that every row lines up with
    # the "file" and "# configs" columns; iterating `keys(dict)` gives no such
    # ordering guarantee.
    df_meta = DataFrame(
        "file" => file_arr,
        "# configs" => [length(conf) for conf in confs_id],
        "E mae mean" => [mean(e_mae[f]) for f in file_arr],
        "E mae std" => [std(e_mae[f]) for f in file_arr],
        "E rmse mean" => [mean(e_rmse[f]) for f in file_arr],
        "E rmse std" => [std(e_rmse[f]) for f in file_arr],
        "F mae mean" => [mean(f_mae[f]) for f in file_arr],
        "F mae std" => [std(f_mae[f]) for f in file_arr],
        "F rmse mean" => [mean(f_rmse[f]) for f in file_arr],
        "F rmse std" => [std(f_rmse[f]) for f in file_arr],
    )

    # write to file
    CSV.write(outpath*"$(elname)_ACE-$(nbody)-$(deg)_train_full_nDPP=$(dpp_batch).csv", df_conf)
    CSV.write(outpath*"$(elname)_ACE-$(nbody)-$(deg)_train_metadata_nDPP=$(dpp_batch).csv", df_meta)

    return sel_ind, cond_num
end


# function get_all_forces_mag(ds::DataSet)
# for c in ds
# force_coord = reduce(hcat, get_values(get_forces(c)))'
# sum(force_coord, dims=1)



"""
    get_all_forces_mag(ds, lb)

For every configuration in `ds`, evaluate the linear model `lb` on that
configuration's force descriptors and return one scalar per configuration: the sum of
the norms of consecutive 3-component chunks of the predicted vector.
"""
function get_all_forces_mag(
    ds::DataSet,
    lb::PotentialLearning.LinearBasisPotential
)
    function predicted_magnitude(config)
        # Flatten the descriptor collections into one list, then stack as columns.
        descr_list = reduce(vcat, get_values(get_force_descriptors(config)))
        dB = reduce(hcat, descr_list)
        # Linear model: first intercept entry plus descriptors times coefficients.
        f = lb.β0[1] .+ dB' * lb.β
        # Split the flat prediction into triples and add up their norms.
        triples = [f[k:k+2] for k = 1:3:length(f)]
        return sum(norm.(triples))
    end

    return [predicted_magnitude(config) for config in ds]
end
Loading

0 comments on commit c114f59

Please sign in to comment.