Skip to content

Commit

Permalink
Merge pull request #53 from cesmix-mit/dbscan-subsampling
Browse files Browse the repository at this point in the history
DBSCAN subsampling
  • Loading branch information
emmanuellujan authored Oct 14, 2023
2 parents a956d94 + 82fc50c commit fe8273f
Show file tree
Hide file tree
Showing 10 changed files with 541 additions and 4 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.2.0"

[deps]
AtomsBase = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a"
Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5"
Determinantal = "2673d5e8-682c-11e9-2dfd-471b09c6c819"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Expand Down
75 changes: 75 additions & 0 deletions examples/DFT-subsampling/subsampling_dpp.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
push!(Base.LOAD_PATH, "../../")

using PotentialLearning
using LinearAlgebra, Random, Statistics, StatsBase, Distributions
using AtomsBase, Unitful, UnitfulAtomic
using InteratomicPotentials, InteratomicBasisPotentials
using Determinantal
using CairoMakie
using InvertedIndices
using CSV
using JLD
using DataFrames

include("subsampling_utils.jl")

# Load dataset -----------------------------------------------------------------
elname = "Si"
elspec = [:Si]
inpath = "../Si-3Body-LAMMPS/"
outpath = "./output/$elname/"

# Read all data
file_arr = readext(inpath, "xyz")
nfile = length(file_arr)
confs_arr = [load_data(inpath*file, ExtXYZ(u"eV", u"")) for file in file_arr]
confs = concat_dataset(confs_arr)

# Id of configurations per file
n = 0
confs_id = Vector{Vector{Int64}}(undef, nfile)
for k = 1:nfile
global n
confs_id[k] = (n+1):(n+length(confs_arr[k]))
n += length(confs_arr[k])
end

# Read single file
# datafile = "Hf_mp100_EOS_1D_form_sorted.xyz"
# confs = load_data(inpath*datafile, ExtXYZ(u"eV", u"Å"))

# Define ACE basis -------------------------------------------------------------
nbody = 4
deg = 5
ace = ACE(species = elspec, # species
body_order = nbody, # n-body
polynomial_degree = deg, # degree of polynomials
wL = 1.0, # Defaults, See ACE.jl documentation
csp = 1.0, # Defaults, See ACE.jl documentation
r0 = 1.0, # minimum distance between atoms
rcutoff = 10.0)

# Update dataset by adding energy (local) descriptors --------------------------
println("Computing local descriptors")
@time e_descr = compute_local_descriptors(confs, ace)
@time f_descr = compute_force_descriptors(confs, ace)
JLD.save(outpath*"$(elname)_energy_descriptors.jld", "e_descr", e_descr)
JLD.save(outpath*"$(elname)_force_descriptors.jld", "f_descr", f_descr)

ds = DataSet(confs .+ e_descr .+ f_descr)
ndata = length(ds)

# Compute cross validation error from training ---------------------------------
batch_size = [80, 40, 20]
sel_ind = Dict{Int64, Vector}()
cond_num = Dict{Int64, Vector}()

for bs in batch_size
println("=============== Starting batch size $bs ===============")
sel_ind[bs], cond_num[bs] = cross_validation_training(ds; ndiv=5, dpp_batch=bs)
end

JLD.save(outpath*"$(elname)_ACE-$(nbody)-$(deg)_DPP_indices_and_condnum.jld",
"ind", sel_ind,
"condnum", cond_num)

134 changes: 134 additions & 0 deletions examples/DFT-subsampling/subsampling_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
function readext(path::String, ext::String)
dir = readdir(path)
substr = [split(f, ".") for f in dir]
id = findall(x -> x[end] == ext, substr)
return dir[id]
end


function concat_dataset(confs::Vector{DataSet})
N = length(confs)
confs_vec = [[confs[i][j] for j = 1:length(confs[i])] for i = 1:N]
confs_all = reduce(vcat, confs_vec)
return DataSet(confs_all)
end


function train_potential(ds_train::DataSet, ace::ACE, dpp_batch::Int64)
# learn with DPP
lb = LBasisPotentialExt(ace)
dpp = kDPP(ds_train, GlobalMean(), DotProduct(); batch_size = dpp_batch)
dpp_inds = get_random_subset(dpp)
lp = learn!(lb, ds_train[dpp_inds], [100, 1], false)

cond_num = compute_cond_num(lp)
return lb, cond_num, dpp_inds
end


function compute_cond_num(lp::PotentialLearning.UnivariateLinearProblem)
A = reduce(hcat, lp.iv_data)'
return cond(A)
end


function compute_cond_num(lp::PotentialLearning.CovariateLinearProblem)
B = reduce(hcat, lp.B)'
dB = reduce(hcat, lp.dB)'
A = [B; dB]
return cond(A)
end


function cross_validation_training(ds; ndiv=10,
dpp_batch=Int(floor(2*length(ds) / ndiv))
)

# retrieve DFT data
energies = get_all_energies(ds)
forces = reduce(vcat,[sum(norm.(get_values(get_forces(ds[c]))))
for c in 1:length(ds)]) # magnitude

# init arrays
cond_num = zeros(ndiv)
e_err = Matrix{Float64}(undef, (length(ds), ndiv))
f_err = Matrix{Float64}(undef, (length(ds), ndiv))
e_mae, e_rmse = Dict(f => zeros(ndiv) for f in file_arr), Dict(f => zeros(ndiv) for f in file_arr)
f_mae, f_rmse = Dict(f => zeros(ndiv) for f in file_arr), Dict(f => zeros(ndiv) for f in file_arr)
sel_ind = Vector{Vector{Int64}}(undef, ndiv)

# make random divisions of data
ndata = length(ds)
ncut = Int(floor(ndata / ndiv))
ind_all = rand(1:ndata, ndata)
ind = [ind_all[(k*ncut+1):((k+1)*ncut)] for k = 0:(ndiv-1)]

# iterate over divisions
for i = 1:ndiv
println("batch $i")
# split train/test sets
train_ind = reduce(vcat, ind[Not(i)])
ds_train = ds[train_ind]

# train using dpp
lb, cond_num[i], dpp_ind = train_potential(ds_train, ace, dpp_batch)
sel_ind[i] = train_ind[dpp_ind]

# get predicted energies
e_pred = get_all_energies(ds, lb)
f_pred = get_all_forces_mag(ds, lb)
e_err[:,i] = energies - e_pred
f_err[:,i] = forces - f_pred

for j = 1:nfile
e_mae[file_arr[j]][i], e_rmse[file_arr[j]][i], _ = calc_metrics(energies[confs_id[j]], e_pred[confs_id[j]])
f_mae[file_arr[j]][i], f_rmse[file_arr[j]][i], _ = calc_metrics(forces[confs_id[j]], f_pred[confs_id[j]])
end
end

# populate DataFrame
df_conf = DataFrame("config" => 1:length(ds),
"file" => reduce(vcat, [[file_arr[j] for i = 1:length(confs_arr[j])] for j = 1:nfile]),
"DFT energy" => energies,
"energy err mean" => mean(e_err, dims=2)[:], # mean error from k-fold CV
"energy err std" => std(e_err, dims=2)[:], # std of error
"DFT force" => forces,
"force err mean" => mean(f_err, dims=2)[:], # mean error from k-fold CV
"force err std" => std(f_err, dims=2)[:], # std of error
)

df_meta = DataFrame("file" => file_arr,
"# configs" => [length(conf) for conf in confs_id],
"E mae mean" => [mean(e_mae[k]) for k in keys(e_mae)],
"E mae std" => [std(e_mae[k]) for k in keys(e_mae)],
"E rmse mean" => [mean(e_rmse[k]) for k in keys(e_rmse)],
"E rmse std" => [std(e_rmse[k]) for k in keys(e_rmse)],
"F mae mean" => [mean(f_mae[k]) for k in keys(f_mae)],
"F mae std" => [std(f_mae[k]) for k in keys(f_mae)],
"F rmse mean" => [mean(f_rmse[k]) for k in keys(f_rmse)],
"F rmse std" => [std(f_rmse[k]) for k in keys(f_rmse)],
)

# write to file
CSV.write(outpath*"$(elname)_ACE-$(nbody)-$(deg)_train_full_nDPP=$(dpp_batch).csv", df_conf)
CSV.write(outpath*"$(elname)_ACE-$(nbody)-$(deg)_train_metadata_nDPP=$(dpp_batch).csv", df_meta)

return sel_ind, cond_num
end


# function get_all_forces_mag(ds::DataSet)
# for c in ds
# force_coord = reduce(hcat, get_values(get_forces(c)))'
# sum(force_coord, dims=1)



function get_all_forces_mag(
ds::DataSet,
lb::PotentialLearning.LinearBasisPotential
)
force_descriptors = [reduce(vcat, get_values(get_force_descriptors(dsi)) ) for dsi in ds]
force_pred = [lb.β0[1] .+ dB' * lb.β for dB in [reduce(hcat, fi) for fi in force_descriptors]]
return [sum(norm.([f[k:k+2] for k = 1:3:length(f)])) for f in force_pred]
end
4 changes: 4 additions & 0 deletions examples/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
[deps]
ACE1pack = "8c4e8d19-0bd6-4234-8309-7210652e3178"
AtomsBase = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Determinantal = "2673d5e8-682c-11e9-2dfd-471b09c6c819"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Expand All @@ -12,13 +14,15 @@ GeometricFlux = "7e08b658-56d3-11e9-2997-919d5b31e4ea"
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
InteratomicBasisPotentials = "37c59853-c2ad-4e3a-930c-a41b2395fb19"
InteratomicPotentials = "a9efe35a-c65d-452d-b8a8-82646cd5cb04"
InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f"
JLD = "4138dd39-2aa7-5051-a626-17a0bb65d9c8"
JuLIP = "945c410c-986d-556a-acb1-167a618e0462"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Polynomials = "f27b6e38-b328-58d1-80ce-0feddd5e7a45"
Expand Down
8 changes: 8 additions & 0 deletions src/Data/datatypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ Abstract type declaring the type of information that is unique to a particular a
abstract type AtomicData <: Data end

CFG_TYPE = Union{AtomsBase.FlexibleSystem,ConfigurationData}

"""
get_values(v::SVector)
Removes units from a position.
"""
get_values(v::SVector) = ustrip(v)

"""
Energy <: ConfigurationData
d :: Real
Expand Down
3 changes: 2 additions & 1 deletion src/Learning/linear-learning-problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,8 @@ function learn!(
)
lp = LinearProblem(ds)
learn!(lp, args...)

copy!(iap.β, lp.β)
copy!(iap.β0, lp.β0)
return lp.Σ
return lp
end
Loading

0 comments on commit fe8273f

Please sign in to comment.