diff --git a/src/main_nep/fitness.cu b/src/main_nep/fitness.cu index 666593b2d..cc70704ef 100644 --- a/src/main_nep/fitness.cu +++ b/src/main_nep/fitness.cu @@ -132,16 +132,19 @@ void Fitness::compute( int population_iter = (para.population_size - 1) / deviceCount + 1; if (generation == 0) { + std::vector dummy_solution(para.number_of_variables * deviceCount, 1.0f); for (int n = 0; n < num_batches; ++n) { - potential->find_force(para, dummy_solution.data(), train_set[n], true, deviceCount); + potential->find_force(para, dummy_solution.data(), train_set[n], true, true, deviceCount); } } else { int batch_id = generation % num_batches; + bool calculate_neighbor = (num_batches > 1) || (generation % 100 == 0); for (int n = 0; n < population_iter; ++n) { const float* individual = population + deviceCount * n * para.number_of_variables; - potential->find_force(para, individual, train_set[batch_id], false, deviceCount); + potential->find_force( + para, individual, train_set[batch_id], false, calculate_neighbor, deviceCount); for (int m = 0; m < deviceCount; ++m) { float energy_shift_per_structure_not_used; auto rmse_energy_array = train_set[batch_id][m].get_rmse_energy( @@ -244,7 +247,7 @@ void Fitness::report_error( { if (0 == (generation + 1) % 100) { int batch_id = generation % num_batches; - potential->find_force(para, elite, train_set[batch_id], false, 1); + potential->find_force(para, elite, train_set[batch_id], false, true, 1); float energy_shift_per_structure; auto rmse_energy_train_array = train_set[batch_id][0].get_rmse_energy(para, energy_shift_per_structure, false, true, 0); @@ -264,7 +267,7 @@ void Fitness::report_error( float rmse_force_test = 0.0f; float rmse_virial_test = 0.0f; if (has_test_set) { - potential->find_force(para, elite, test_set, false, 1); + potential->find_force(para, elite, test_set, false, true, 1); float energy_shift_per_structure_not_used; auto rmse_energy_test_array = test_set[0].get_rmse_energy(para, energy_shift_per_structure_not_used, false, false, 0); @@ -396,7 +399,7 @@ void Fitness::predict(Parameters& para, float* elite) FILE* fid_energy = my_fopen("energy_train.out", "w"); FILE* fid_virial = my_fopen("virial_train.out", "w"); for (int batch_id = 0; batch_id < num_batches; ++batch_id) { - potential->find_force(para, elite, train_set[batch_id], false, 1); + potential->find_force(para, elite, train_set[batch_id], false, true, 1); update_energy_force_virial(fid_energy, fid_force, fid_virial, train_set[batch_id][0]); } fclose(fid_energy); @@ -405,14 +408,14 @@ void Fitness::predict(Parameters& para, float* elite) } else if (para.train_mode == 1) { FILE* fid_dipole = my_fopen("dipole_train.out", "w"); for (int batch_id = 0; batch_id < num_batches; ++batch_id) { - potential->find_force(para, elite, train_set[batch_id], false, 1); + potential->find_force(para, elite, train_set[batch_id], false, true, 1); update_dipole(fid_dipole, train_set[batch_id][0]); } fclose(fid_dipole); } else if (para.train_mode == 2) { FILE* fid_polarizability = my_fopen("polarizability_train.out", "w"); for (int batch_id = 0; batch_id < num_batches; ++batch_id) { - potential->find_force(para, elite, train_set[batch_id], false, 1); + potential->find_force(para, elite, train_set[batch_id], false, true, 1); update_polarizability(fid_polarizability, train_set[batch_id][0]); } fclose(fid_polarizability); diff --git a/src/main_nep/nep3.cu b/src/main_nep/nep3.cu index b9a27afa8..eabd1d7d0 100644 --- a/src/main_nep/nep3.cu +++ b/src/main_nep/nep3.cu @@ -743,10 +743,10 @@ static __global__ void find_force_ZBL( find_f_and_fp_zbl(ZBL_para, zizj, a_inv, rc_inner, rc_outer, d12, d12inv, f, fp); } else { find_f_and_fp_zbl(zizj, a_inv, zbl.rc_inner, zbl.rc_outer, d12, d12inv, f, fp); - } + } float f2 = fp * d12inv * 0.5f; float f12[3] = {r12[0] * f2, r12[1] * f2, r12[2] * f2}; - + atomicAdd(&g_fx[n1], f12[0]); atomicAdd(&g_fy[n1], f12[1]); atomicAdd(&g_fz[n1], f12[2]); @@ -776,6 +776,7 @@ void NEP3::find_force( const float* parameters, std::vector& dataset, bool calculate_q_scaler, + bool calculate_neighbor, int device_in_this_iter) { float rc2_radial = para.rc_radial * para.rc_radial; @@ -792,18 +793,21 @@ void NEP3::find_force( CHECK(cudaSetDevice(device_id)); const int block_size = 32; const int grid_size = (dataset[device_id].N - 1) / block_size + 1; - gpu_find_neighbor_list<<>>( - dataset[device_id].N, dataset[device_id].Na.data(), dataset[device_id].Na_sum.data(), - rc2_radial, rc2_angular, dataset[device_id].box.data(), - dataset[device_id].box_original.data(), dataset[device_id].num_cell.data(), - dataset[device_id].r.data(), dataset[device_id].r.data() + dataset[device_id].N, - dataset[device_id].r.data() + dataset[device_id].N * 2, nep_data[device_id].NN_radial.data(), - nep_data[device_id].NL_radial.data(), nep_data[device_id].NN_angular.data(), - nep_data[device_id].NL_angular.data(), nep_data[device_id].x12_radial.data(), - nep_data[device_id].y12_radial.data(), nep_data[device_id].z12_radial.data(), - nep_data[device_id].x12_angular.data(), nep_data[device_id].y12_angular.data(), - nep_data[device_id].z12_angular.data()); - CUDA_CHECK_KERNEL + + if (calculate_neighbor) { + gpu_find_neighbor_list<<>>( + dataset[device_id].N, dataset[device_id].Na.data(), dataset[device_id].Na_sum.data(), + rc2_radial, rc2_angular, dataset[device_id].box.data(), + dataset[device_id].box_original.data(), dataset[device_id].num_cell.data(), + dataset[device_id].r.data(), dataset[device_id].r.data() + dataset[device_id].N, + dataset[device_id].r.data() + dataset[device_id].N * 2, + nep_data[device_id].NN_radial.data(), nep_data[device_id].NL_radial.data(), + nep_data[device_id].NN_angular.data(), nep_data[device_id].NL_angular.data(), + nep_data[device_id].x12_radial.data(), nep_data[device_id].y12_radial.data(), + nep_data[device_id].z12_radial.data(), nep_data[device_id].x12_angular.data(), + nep_data[device_id].y12_angular.data(), nep_data[device_id].z12_angular.data()); + CUDA_CHECK_KERNEL + } find_descriptors_radial<<>>( dataset[device_id].N, nep_data[device_id].NN_radial.data(), diff --git a/src/main_nep/nep3.cuh b/src/main_nep/nep3.cuh index 301fa2080..6d6beeafe 100644 --- a/src/main_nep/nep3.cuh +++ b/src/main_nep/nep3.cuh @@ -99,6 +99,7 @@ public: const float* parameters, std::vector& dataset, bool calculate_q_scaler, + bool calculate_neighbor, int deviceCount); private: diff --git a/src/main_nep/potential.cuh b/src/main_nep/potential.cuh index 9b13cb87d..395026e80 100644 --- a/src/main_nep/potential.cuh +++ b/src/main_nep/potential.cuh @@ -28,5 +28,6 @@ public: const float* parameters, std::vector& dataset, bool calculate_q_scaler, + bool calculate_neighbor, int DeviceCount) = 0; };