diff --git a/src/learn/learn.cpp b/src/learn/learn.cpp index 8326ab24185..29a2b788ecc 100644 --- a/src/learn/learn.cpp +++ b/src/learn/learn.cpp @@ -598,9 +598,9 @@ namespace Learner void learn_worker(Thread& th, std::atomic& counter, uint64_t limit); - void update_weights(const PSVector& psv, uint64_t epoch); + void update_weights(uint64_t epoch); - void calc_loss(const PSVector& psv, uint64_t epoch); + void calc_loss(uint64_t epoch); void calc_loss_worker( Thread& th, @@ -701,6 +701,16 @@ namespace Learner }); mainThread->wait_for_worker_finished(); + if (validation_data.size() != params.validation_count) + { + auto out = sync_region_cout.new_region(); + out + << "WARNING (learn): Error reading validation data. Read " << validation_data.size() + << " out of " << params.validation_count << '\n' + << "WARNING (learn): This either means that less than 1% of the validation data passed the filter" + << " or the file is empty\n"; + } + return validation_data; } @@ -714,23 +724,11 @@ namespace Learner Eval::NNUE::verify_any_net_loaded(); - const PSVector validation_data = fetch_next_validation_set(); - - if (validation_data.size() != params.validation_count) - { - auto out = sync_region_cout.new_region(); - out - << "INFO (learn): Error reading validation data. Read " << validation_data.size() - << " out of " << params.validation_count << '\n' - << "INFO (learn): This either means that less than 1% of the validation data passed the filter" - << " or the file is empty\n"; - - return; - } + uint64_t epoch = 0; if (params.newbob_decay != 1.0) { - calc_loss(validation_data, 0); + calc_loss(epoch++); best_loss = latest_loss_sum / latest_loss_count; latest_loss_sum = 0.0; @@ -742,7 +740,7 @@ namespace Learner stop_flag = false; - for(uint64_t epoch = 1; epoch <= epochs; ++epoch) + for(; epoch <= epochs; ++epoch) { std::atomic counter{0}; @@ -757,7 +755,7 @@ namespace Learner if (stop_flag) break; - update_weights(validation_data, epoch); + update_weights(epoch); if (stop_flag) break; @@ -860,7 +858,7 @@ namespace Learner } } - void LearnerThink::update_weights(const PSVector& psv, uint64_t epoch) + void LearnerThink::update_weights(uint64_t epoch) { // I'm not sure this fencing is correct. But either way there // should be no real issues happening since @@ -887,17 +885,19 @@ namespace Learner loss_output_count = 0; // loss calculation - calc_loss(psv, epoch); + calc_loss(epoch); Eval::NNUE::check_health(); } } - void LearnerThink::calc_loss(const PSVector& psv, uint64_t epoch) + void LearnerThink::calc_loss(uint64_t epoch) { TT.new_search(); TimePoint elapsed = now() - Search::Limits.startTime + 1; + const auto psv = fetch_next_validation_set(); + auto out = sync_region_cout.new_region(); out << "\n";