Skip to content

Commit

Permalink
speed up full-batch training by 15%
Browse files: browse the repository at this point in the history
  • Loading branch information
brucefan1983 committed Mar 3, 2023
1 parent c7f986a commit 90b77ba
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 21 deletions.
17 changes: 10 additions & 7 deletions src/main_nep/fitness.cu
Original file line number Diff line number Diff line change
Expand Up @@ -132,16 +132,19 @@ void Fitness::compute(
int population_iter = (para.population_size - 1) / deviceCount + 1;

if (generation == 0) {

std::vector<float> dummy_solution(para.number_of_variables * deviceCount, 1.0f);
for (int n = 0; n < num_batches; ++n) {
potential->find_force(para, dummy_solution.data(), train_set[n], true, deviceCount);
potential->find_force(para, dummy_solution.data(), train_set[n], true, true, deviceCount);
}

} else {
int batch_id = generation % num_batches;
bool calculate_neighbor = (num_batches > 1) || (generation % 100 == 0);
for (int n = 0; n < population_iter; ++n) {
const float* individual = population + deviceCount * n * para.number_of_variables;
potential->find_force(para, individual, train_set[batch_id], false, deviceCount);
potential->find_force(
para, individual, train_set[batch_id], false, calculate_neighbor, deviceCount);
for (int m = 0; m < deviceCount; ++m) {
float energy_shift_per_structure_not_used;
auto rmse_energy_array = train_set[batch_id][m].get_rmse_energy(
Expand Down Expand Up @@ -244,7 +247,7 @@ void Fitness::report_error(
{
if (0 == (generation + 1) % 100) {
int batch_id = generation % num_batches;
potential->find_force(para, elite, train_set[batch_id], false, 1);
potential->find_force(para, elite, train_set[batch_id], false, true, 1);
float energy_shift_per_structure;
auto rmse_energy_train_array =
train_set[batch_id][0].get_rmse_energy(para, energy_shift_per_structure, false, true, 0);
Expand All @@ -264,7 +267,7 @@ void Fitness::report_error(
float rmse_force_test = 0.0f;
float rmse_virial_test = 0.0f;
if (has_test_set) {
potential->find_force(para, elite, test_set, false, 1);
potential->find_force(para, elite, test_set, false, true, 1);
float energy_shift_per_structure_not_used;
auto rmse_energy_test_array =
test_set[0].get_rmse_energy(para, energy_shift_per_structure_not_used, false, false, 0);
Expand Down Expand Up @@ -396,7 +399,7 @@ void Fitness::predict(Parameters& para, float* elite)
FILE* fid_energy = my_fopen("energy_train.out", "w");
FILE* fid_virial = my_fopen("virial_train.out", "w");
for (int batch_id = 0; batch_id < num_batches; ++batch_id) {
potential->find_force(para, elite, train_set[batch_id], false, 1);
potential->find_force(para, elite, train_set[batch_id], false, true, 1);
update_energy_force_virial(fid_energy, fid_force, fid_virial, train_set[batch_id][0]);
}
fclose(fid_energy);
Expand All @@ -405,14 +408,14 @@ void Fitness::predict(Parameters& para, float* elite)
} else if (para.train_mode == 1) {
FILE* fid_dipole = my_fopen("dipole_train.out", "w");
for (int batch_id = 0; batch_id < num_batches; ++batch_id) {
potential->find_force(para, elite, train_set[batch_id], false, 1);
potential->find_force(para, elite, train_set[batch_id], false, true, 1);
update_dipole(fid_dipole, train_set[batch_id][0]);
}
fclose(fid_dipole);
} else if (para.train_mode == 2) {
FILE* fid_polarizability = my_fopen("polarizability_train.out", "w");
for (int batch_id = 0; batch_id < num_batches; ++batch_id) {
potential->find_force(para, elite, train_set[batch_id], false, 1);
potential->find_force(para, elite, train_set[batch_id], false, true, 1);
update_polarizability(fid_polarizability, train_set[batch_id][0]);
}
fclose(fid_polarizability);
Expand Down
32 changes: 18 additions & 14 deletions src/main_nep/nep3.cu
Original file line number Diff line number Diff line change
Expand Up @@ -743,10 +743,10 @@ static __global__ void find_force_ZBL(
find_f_and_fp_zbl(ZBL_para, zizj, a_inv, rc_inner, rc_outer, d12, d12inv, f, fp);
} else {
find_f_and_fp_zbl(zizj, a_inv, zbl.rc_inner, zbl.rc_outer, d12, d12inv, f, fp);
}
}
float f2 = fp * d12inv * 0.5f;
float f12[3] = {r12[0] * f2, r12[1] * f2, r12[2] * f2};

atomicAdd(&g_fx[n1], f12[0]);
atomicAdd(&g_fy[n1], f12[1]);
atomicAdd(&g_fz[n1], f12[2]);
Expand Down Expand Up @@ -776,6 +776,7 @@ void NEP3::find_force(
const float* parameters,
std::vector<Dataset>& dataset,
bool calculate_q_scaler,
bool calculate_neighbor,
int device_in_this_iter)
{
float rc2_radial = para.rc_radial * para.rc_radial;
Expand All @@ -792,18 +793,21 @@ void NEP3::find_force(
CHECK(cudaSetDevice(device_id));
const int block_size = 32;
const int grid_size = (dataset[device_id].N - 1) / block_size + 1;
gpu_find_neighbor_list<<<dataset[device_id].Nc, 256>>>(
dataset[device_id].N, dataset[device_id].Na.data(), dataset[device_id].Na_sum.data(),
rc2_radial, rc2_angular, dataset[device_id].box.data(),
dataset[device_id].box_original.data(), dataset[device_id].num_cell.data(),
dataset[device_id].r.data(), dataset[device_id].r.data() + dataset[device_id].N,
dataset[device_id].r.data() + dataset[device_id].N * 2, nep_data[device_id].NN_radial.data(),
nep_data[device_id].NL_radial.data(), nep_data[device_id].NN_angular.data(),
nep_data[device_id].NL_angular.data(), nep_data[device_id].x12_radial.data(),
nep_data[device_id].y12_radial.data(), nep_data[device_id].z12_radial.data(),
nep_data[device_id].x12_angular.data(), nep_data[device_id].y12_angular.data(),
nep_data[device_id].z12_angular.data());
CUDA_CHECK_KERNEL

if (calculate_neighbor) {
gpu_find_neighbor_list<<<dataset[device_id].Nc, 256>>>(
dataset[device_id].N, dataset[device_id].Na.data(), dataset[device_id].Na_sum.data(),
rc2_radial, rc2_angular, dataset[device_id].box.data(),
dataset[device_id].box_original.data(), dataset[device_id].num_cell.data(),
dataset[device_id].r.data(), dataset[device_id].r.data() + dataset[device_id].N,
dataset[device_id].r.data() + dataset[device_id].N * 2,
nep_data[device_id].NN_radial.data(), nep_data[device_id].NL_radial.data(),
nep_data[device_id].NN_angular.data(), nep_data[device_id].NL_angular.data(),
nep_data[device_id].x12_radial.data(), nep_data[device_id].y12_radial.data(),
nep_data[device_id].z12_radial.data(), nep_data[device_id].x12_angular.data(),
nep_data[device_id].y12_angular.data(), nep_data[device_id].z12_angular.data());
CUDA_CHECK_KERNEL
}

find_descriptors_radial<<<grid_size, block_size>>>(
dataset[device_id].N, nep_data[device_id].NN_radial.data(),
Expand Down
1 change: 1 addition & 0 deletions src/main_nep/nep3.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ public:
const float* parameters,
std::vector<Dataset>& dataset,
bool calculate_q_scaler,
bool calculate_neighbor,
int deviceCount);

private:
Expand Down
1 change: 1 addition & 0 deletions src/main_nep/potential.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,6 @@ public:
const float* parameters,
std::vector<Dataset>& dataset,
bool calculate_q_scaler,
bool calculate_neighbor,
int DeviceCount) = 0;
};

0 comments on commit 90b77ba

Please sign in to comment.