Small changes in NA example.
emmanuellujan committed Jun 17, 2024
1 parent e6bd81b commit 5e3e0ec
Showing 4 changed files with 33 additions and 14 deletions.
47 changes: 33 additions & 14 deletions examples/DPP-ACE-Na/fit-dpp-ace-na.jl
@@ -1,16 +1,28 @@
# # Subsample Na dataset with DPP and fit energies with ACE

# ## Load packages and define paths.

# Load packages
using Unitful, UnitfulAtomic
using AtomsBase, InteratomicPotentials, PotentialLearning
using LinearAlgebra, Plots

# Load dataset
# Define paths.
path = joinpath(dirname(pathof(PotentialLearning)), "../examples/DPP-ACE-Na")
confs, thermo = load_data("$path/../data/Na/liquify_sodium.yaml", YAML(:Na, u"eV", u""))
ds_path = "$path/../data/Na/liquify_sodium.yaml"

# ## Load atomistic dataset and split it into training and test.

# Load atomistic dataset: atomistic configurations (atom positions, geometry, etc.) + DFT data (energies, forces, etc.).
confs, thermo = load_data(ds_path, YAML(:Na, u"eV", u""))
confs, thermo = confs[220:end], thermo[220:end]

# Split dataset
# Split atomistic dataset into training and test.
conf_train, conf_test = confs[1:1000], confs[1001:end]

# Define ACE
# ## Create ACE basis, compute energy descriptors and add them to the dataset.

# Create ACE basis.
ace = ACE(species = [:Na], # species
body_order = 4, # 4-body
polynomial_degree = 8, # 8 degree polynomials
@@ -19,36 +31,43 @@ ace = ACE(species = [:Na], # species
r0 = 1.0, # minimum distance between atoms
rcutoff = 5.0) # cutoff radius

# Update training dataset by adding energy (local) descriptors
# Update training dataset by adding energy (local) descriptors.
println("Computing local descriptors of training dataset")
e_descr_train = compute_local_descriptors(conf_train, ace)
#e_descr_train = JLD.load("data/sodium_empirical_full.jld", "descriptors")
e_descr_train = compute_local_descriptors(conf_train, ace) # JLD.load("data/sodium_empirical_full.jld", "descriptors")

# Update training dataset by adding energy descriptors.
ds_train = DataSet(conf_train .+ e_descr_train)

# Learn using DPP
lb = LBasisPotential(ace)
# ## Subsampling via DPP.

# Create DPP subselector.
dpp = kDPP(ds_train, GlobalMean(), DotProduct(); batch_size = 200)

# Subsample training dataset.
dpp_inds = get_random_subset(dpp)

# ## Learn ACE coefficients based on ACE descriptors and DFT data.
lb = LBasisPotential(ace)
α = 1e-8
Σ = learn!(lb, ds_train[dpp_inds], α)
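The fit above passes a small parameter `α = 1e-8` to `learn!`. This diff does not show how `learn!` uses it, so the following is only a generic sketch of a regularized linear least-squares fit, assuming `α` acts as a small Tikhonov (ridge) term. The helper name `sketch_ridge_fit` and the toy matrices are hypothetical and not part of PotentialLearning.jl.

```julia
using LinearAlgebra

# Hypothetical helper (not part of PotentialLearning.jl): solve the
# regularized normal equations (A'A + αI) β = A'b for the coefficients β.
function sketch_ridge_fit(A::AbstractMatrix, b::AbstractVector, α::Real)
    return (A' * A + α * I) \ (A' * b)
end

# Toy data: rows of A play the role of per-configuration descriptors,
# and b plays the role of reference energies.
A = [1.0 0.0; 0.0 1.0; 1.0 1.0]
β_true = [2.0, -1.0]
b = A * β_true

β = sketch_ridge_fit(A, b, 1e-8)
println("Recovered coefficients: $β")
```

With a regularization term this small, the recovered coefficients stay very close to the unregularized least-squares solution; a larger `α` trades fit accuracy for numerical stability.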

# Post-process output
# ## Post-process output: calculate metrics, create plots, and save results.

# Update test dataset by adding energy and force descriptors
# Update test dataset by adding energy descriptors.
println("Computing local descriptors of test dataset")
e_descr_test = compute_local_descriptors(conf_test, ace)
ds_test = DataSet(conf_test .+ e_descr_test)

# Get true and predicted energy values (assuming that all configurations have the same no. of atoms)
# Get true and predicted energy values (assuming that all configurations have the same no. of atoms).
n = size(get_system(ds_train[1]))[1]
e_train, e_train_pred = get_all_energies(ds_train)/n, get_all_energies(ds_train, lb)/n
e_test, e_test_pred = get_all_energies(ds_test)/n, get_all_energies(ds_test, lb)/n

# Compute and print metrics
# Compute and print metrics.
e_mae, e_rmse, e_rsq = calc_metrics(e_train, e_train_pred)
println("MAE: $e_mae, RMSE: $e_rmse, RSQ: $e_rsq")
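For readers unfamiliar with the three reported metrics, here is a minimal sketch using their conventional definitions (mean absolute error, root mean squared error, coefficient of determination). It assumes `calc_metrics` follows these definitions, which this diff does not confirm; the helper name `sketch_metrics` and the sample values are hypothetical.

```julia
using Statistics

# Hypothetical helper (not part of PotentialLearning.jl): conventional
# definitions of MAE, RMSE, and R² for true vs. predicted values.
function sketch_metrics(y_true::AbstractVector, y_pred::AbstractVector)
    err = y_pred .- y_true
    mae = mean(abs.(err))                        # mean absolute error
    rmse = sqrt(mean(abs2, err))                 # root mean squared error
    ss_res = sum(abs2, err)                      # residual sum of squares
    ss_tot = sum(abs2, y_true .- mean(y_true))   # total sum of squares
    rsq = 1 - ss_res / ss_tot                    # coefficient of determination
    return mae, rmse, rsq
end

# Made-up per-atom energies, for illustration only.
y_true = [1.0, 2.0, 3.0, 4.0]
y_pred = [1.1, 1.9, 3.2, 3.8]

mae, rmse, rsq = sketch_metrics(y_true, y_pred)
println("MAE: $mae, RMSE: $rmse, RSQ: $rsq")
```

Note that the script above computes metrics on the training energies only; the same call on `e_test` and `e_test_pred` would give the corresponding test-set values.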

# Plot energy error scatter
# Plot energy error.
e_err_train, e_err_test = (e_train_pred - e_train), (e_test_pred - e_test)
dpp_inds2 = get_random_subset(dpp; batch_size = 20)
p = scatter( e_train, e_err_train, label = "Training", color = :blue,
File renamed without changes.
