Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Fetch new validation data for each calc_loss #279

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 21 additions & 21 deletions src/learn/learn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -598,9 +598,9 @@ namespace Learner

void learn_worker(Thread& th, std::atomic<uint64_t>& counter, uint64_t limit);

void update_weights(const PSVector& psv, uint64_t epoch);
void update_weights(uint64_t epoch);

void calc_loss(const PSVector& psv, uint64_t epoch);
void calc_loss(uint64_t epoch);

void calc_loss_worker(
Thread& th,
Expand Down Expand Up @@ -701,6 +701,16 @@ namespace Learner
});
mainThread->wait_for_worker_finished();

if (validation_data.size() != params.validation_count)
{
auto out = sync_region_cout.new_region();
out
<< "WARNING (learn): Error reading validation data. Read " << validation_data.size()
<< " out of " << params.validation_count << '\n'
<< "WARNING (learn): This either means that less than 1% of the validation data passed the filter"
<< " or the file is empty\n";
}

return validation_data;
}

Expand All @@ -714,23 +724,11 @@ namespace Learner

Eval::NNUE::verify_any_net_loaded();

const PSVector validation_data = fetch_next_validation_set();

if (validation_data.size() != params.validation_count)
{
auto out = sync_region_cout.new_region();
out
<< "INFO (learn): Error reading validation data. Read " << validation_data.size()
<< " out of " << params.validation_count << '\n'
<< "INFO (learn): This either means that less than 1% of the validation data passed the filter"
<< " or the file is empty\n";

return;
}
uint64_t epoch = 0;

if (params.newbob_decay != 1.0) {

calc_loss(validation_data, 0);
calc_loss(epoch++);

best_loss = latest_loss_sum / latest_loss_count;
latest_loss_sum = 0.0;
Expand All @@ -742,7 +740,7 @@ namespace Learner

stop_flag = false;

for(uint64_t epoch = 1; epoch <= epochs; ++epoch)
for(; epoch <= epochs; ++epoch)
{
std::atomic<uint64_t> counter{0};

Expand All @@ -757,7 +755,7 @@ namespace Learner
if (stop_flag)
break;

update_weights(validation_data, epoch);
update_weights(epoch);

if (stop_flag)
break;
Expand Down Expand Up @@ -860,7 +858,7 @@ namespace Learner
}
}

void LearnerThink::update_weights(const PSVector& psv, uint64_t epoch)
void LearnerThink::update_weights(uint64_t epoch)
{
// I'm not sure this fencing is correct. But either way there
// should be no real issues happening since
Expand All @@ -887,17 +885,19 @@ namespace Learner
loss_output_count = 0;

// loss calculation
calc_loss(psv, epoch);
calc_loss(epoch);

Eval::NNUE::check_health();
}
}

void LearnerThink::calc_loss(const PSVector& psv, uint64_t epoch)
void LearnerThink::calc_loss(uint64_t epoch)
{
TT.new_search();
TimePoint elapsed = now() - Search::Limits.startTime + 1;

const auto psv = fetch_next_validation_set();

auto out = sync_region_cout.new_region();

out << "\n";
Expand Down