Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow for choosing shear virial weight #320

Merged
merged 3 commits into from
Nov 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 16 additions & 31 deletions src/main_nep/dataset.cu
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ float Dataset::get_rmse_energy(
return sqrt(error_ave / Nc);
}

float Dataset::get_rmse_virial(const bool use_weight, int device_id)
float Dataset::get_rmse_virial(Parameters& para, const bool use_weight, int device_id)
{
CHECK(cudaSetDevice(device_id));
int num_virial_configurations = 0;
Expand All @@ -491,11 +491,8 @@ float Dataset::get_rmse_virial(const bool use_weight, int device_id)
CHECK(cudaMemcpy(error_cpu.data(), error_gpu.data(), mem, cudaMemcpyDeviceToHost));
for (int n = 0; n < Nc; ++n) {
if (structures[n].has_virial) {
if (use_weight) {
error_ave += weight_cpu[n] * weight_cpu[n] * error_cpu[n];
} else {
error_ave += error_cpu[n];
}
float total_weight = use_weight ? weight_cpu[n] * weight_cpu[n] : 1.0f;
error_ave += total_weight * error_cpu[n];
}
}

Expand All @@ -505,11 +502,8 @@ float Dataset::get_rmse_virial(const bool use_weight, int device_id)
CHECK(cudaMemcpy(error_cpu.data(), error_gpu.data(), mem, cudaMemcpyDeviceToHost));
for (int n = 0; n < Nc; ++n) {
if (structures[n].has_virial) {
if (use_weight) {
error_ave += weight_cpu[n] * weight_cpu[n] * error_cpu[n];
} else {
error_ave += error_cpu[n];
}
float total_weight = use_weight ? weight_cpu[n] * weight_cpu[n] : 1.0f;
error_ave += total_weight * error_cpu[n];
}
}

Expand All @@ -519,11 +513,8 @@ float Dataset::get_rmse_virial(const bool use_weight, int device_id)
CHECK(cudaMemcpy(error_cpu.data(), error_gpu.data(), mem, cudaMemcpyDeviceToHost));
for (int n = 0; n < Nc; ++n) {
if (structures[n].has_virial) {
if (use_weight) {
error_ave += weight_cpu[n] * weight_cpu[n] * error_cpu[n];
} else {
error_ave += error_cpu[n];
}
float total_weight = use_weight ? weight_cpu[n] * weight_cpu[n] : 1.0f;
error_ave += total_weight * error_cpu[n];
}
}

Expand All @@ -533,11 +524,9 @@ float Dataset::get_rmse_virial(const bool use_weight, int device_id)
CHECK(cudaMemcpy(error_cpu.data(), error_gpu.data(), mem, cudaMemcpyDeviceToHost));
for (int n = 0; n < Nc; ++n) {
if (structures[n].has_virial) {
if (use_weight) {
error_ave += weight_cpu[n] * weight_cpu[n] * error_cpu[n];
} else {
error_ave += error_cpu[n];
}
float total_weight =
use_weight ? weight_cpu[n] * weight_cpu[n] * para.lambda_shear * para.lambda_shear : 1.0f;
error_ave += total_weight * error_cpu[n];
}
}

Expand All @@ -547,11 +536,9 @@ float Dataset::get_rmse_virial(const bool use_weight, int device_id)
CHECK(cudaMemcpy(error_cpu.data(), error_gpu.data(), mem, cudaMemcpyDeviceToHost));
for (int n = 0; n < Nc; ++n) {
if (structures[n].has_virial) {
if (use_weight) {
error_ave += weight_cpu[n] * weight_cpu[n] * error_cpu[n];
} else {
error_ave += error_cpu[n];
}
float total_weight =
use_weight ? weight_cpu[n] * weight_cpu[n] * para.lambda_shear * para.lambda_shear : 1.0f;
error_ave += total_weight * error_cpu[n];
}
}

Expand All @@ -561,11 +548,9 @@ float Dataset::get_rmse_virial(const bool use_weight, int device_id)
CHECK(cudaMemcpy(error_cpu.data(), error_gpu.data(), mem, cudaMemcpyDeviceToHost));
for (int n = 0; n < Nc; ++n) {
if (structures[n].has_virial) {
if (use_weight) {
error_ave += weight_cpu[n] * weight_cpu[n] * error_cpu[n];
} else {
error_ave += error_cpu[n];
}
float total_weight =
use_weight ? weight_cpu[n] * weight_cpu[n] * para.lambda_shear * para.lambda_shear : 1.0f;
error_ave += total_weight * error_cpu[n];
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/main_nep/dataset.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public:
float get_rmse_force(Parameters& para, const bool use_weight, int device_id);
float get_rmse_energy(
float& energy_shift_per_structure, const bool use_weight, const bool do_shift, int device_id);
float get_rmse_virial(const bool use_weight, int device_id);
float get_rmse_virial(Parameters& para, const bool use_weight, int device_id);

private:
void copy_structures(std::vector<Structure>& structures_input, int n1, int n2);
Expand Down
6 changes: 3 additions & 3 deletions src/main_nep/fitness.cu
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ void Fitness::compute(
fitness[deviceCount * n + m + 1 * para.population_size] =
para.lambda_f * train_set[batch_id][m].get_rmse_force(para, true, m);
fitness[deviceCount * n + m + 2 * para.population_size] =
para.lambda_v * train_set[batch_id][m].get_rmse_virial(true, m);
para.lambda_v * train_set[batch_id][m].get_rmse_virial(para, true, m);
}
}
}
Expand Down Expand Up @@ -225,7 +225,7 @@ void Fitness::report_error(
float rmse_energy_train =
train_set[batch_id][0].get_rmse_energy(energy_shift_per_structure, false, true, 0);
float rmse_force_train = train_set[batch_id][0].get_rmse_force(para, false, 0);
float rmse_virial_train = train_set[batch_id][0].get_rmse_virial(false, 0);
float rmse_virial_train = train_set[batch_id][0].get_rmse_virial(para, false, 0);

// correct the last bias parameter in the NN
if (para.train_mode == 0) {
Expand All @@ -237,7 +237,7 @@ void Fitness::report_error(
float rmse_energy_test =
test_set[0].get_rmse_energy(energy_shift_per_structure_not_used, false, false, 0);
float rmse_force_test = test_set[0].get_rmse_force(para, false, 0);
float rmse_virial_test = test_set[0].get_rmse_virial(false, 0);
float rmse_virial_test = test_set[0].get_rmse_virial(para, false, 0);

FILE* fid_nep = my_fopen("nep.txt", "w");
write_nep_txt(fid_nep, para, elite);
Expand Down
29 changes: 29 additions & 0 deletions src/main_nep/parameters.cu
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ void Parameters::set_default_parameters()
is_lambda_e_set = false;
is_lambda_f_set = false;
is_lambda_v_set = false;
is_lambda_shear_set = false;
is_batch_set = false;
is_population_set = false;
is_generation_set = false;
Expand All @@ -82,6 +83,7 @@ void Parameters::set_default_parameters()
lambda_1 = lambda_2 = 5.0e-2f; // good default based on our tests
lambda_e = lambda_f = 1.0f; // energy and force are more important
lambda_v = 0.1f; // virial is less important
lambda_shear = 1.0f; // do not weight shear virial more by default
force_delta = 0.0f; // no modification of force loss
batch_size = 1000; // large enough in most cases
population_size = 50; // almost optimal
Expand Down Expand Up @@ -271,6 +273,12 @@ void Parameters::report_inputs()
printf(" (default) lambda_v = %g.\n", lambda_v);
}

if (is_lambda_shear_set) {
printf(" (input) lambda_shear = %g.\n", lambda_shear);
} else {
printf(" (default) lambda_shear = %g.\n", lambda_shear);
}

if (is_force_delta_set) {
printf(" (input) force_delta = %g.\n", force_delta);
} else {
Expand Down Expand Up @@ -346,6 +354,8 @@ void Parameters::parse_one_keyword(std::vector<std::string>& tokens)
parse_lambda_f(param, num_param);
} else if (strcmp(param[0], "lambda_v") == 0) {
parse_lambda_v(param, num_param);
} else if (strcmp(param[0], "lambda_shear") == 0) {
parse_lambda_shear(param, num_param);
} else if (strcmp(param[0], "type_weight") == 0) {
parse_type_weight(param, num_param);
} else if (strcmp(param[0], "force_delta") == 0) {
Expand Down Expand Up @@ -710,6 +720,25 @@ void Parameters::parse_lambda_v(const char** param, int num_param)
}
}

// Parse the "lambda_shear" keyword from the input file.
// Usage: lambda_shear <weight>
// <weight> is an extra non-negative multiplier applied to the shear virial
// components in the virial RMSE loss (default is 1.0f, i.e. shear components
// are not weighted differently). Exits via PRINT_INPUT_ERROR on bad input.
void Parameters::parse_lambda_shear(const char** param, int num_param)
{
  is_lambda_shear_set = true;

  // keyword itself plus exactly one value
  if (num_param != 2) {
    PRINT_INPUT_ERROR("lambda_shear should have 1 parameter.\n");
  }

  double lambda_shear_tmp = 0.0;
  if (!is_valid_real(param[1], &lambda_shear_tmp)) {
    PRINT_INPUT_ERROR("Shear virial weight should be a number.\n");
  }
  lambda_shear = lambda_shear_tmp;

  if (lambda_shear < 0.0f) {
    // fixed user-facing message grammar ("should >= 0" -> "should be >= 0")
    PRINT_INPUT_ERROR("Shear virial weight should be >= 0.");
  }
}

void Parameters::parse_batch(const char** param, int num_param)
{
is_batch_set = true;
Expand Down
3 changes: 3 additions & 0 deletions src/main_nep/parameters.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ public:
float lambda_e; // weight parameter for energy RMSE loss
float lambda_f; // weight parameter for force RMSE loss
float lambda_v; // weight parameter for virial RMSE loss
float lambda_shear; // extra weight parameter for shear virial
float force_delta; // a parameters used to modify the force loss
bool enable_zbl; // true for including the universal ZBL potential
float zbl_rc_inner; // inner cutoff for the universal ZBL potential
Expand All @@ -66,6 +67,7 @@ public:
bool is_lambda_e_set;
bool is_lambda_f_set;
bool is_lambda_v_set;
bool is_lambda_shear_set;
bool is_batch_set;
bool is_population_set;
bool is_generation_set;
Expand Down Expand Up @@ -113,6 +115,7 @@ private:
void parse_lambda_e(const char** param, int num_param);
void parse_lambda_f(const char** param, int num_param);
void parse_lambda_v(const char** param, int num_param);
void parse_lambda_shear(const char** param, int num_param);
void parse_force_delta(const char** param, int num_param);
void parse_batch(const char** param, int num_param);
void parse_population(const char** param, int num_param);
Expand Down