Skip to content

Commit

Permalink
Rescore tb v5 (#2)
Browse files Browse the repository at this point in the history
* Make lc0 output v5 training data.

* Finish merge of v5 data into rescorer tb.

* Fixes for rescoring v4 data.

* Revert some unneeded formatting changes.
  • Loading branch information
Tilps authored Mar 22, 2020
1 parent 36399ac commit 3fe4716
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 131 deletions.
70 changes: 51 additions & 19 deletions src/neural/writer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,12 @@ uint64_t ReverseBitsInBytes(uint64_t v) {
}
} // namespace

InputPlanes PlanesFromTrainingData(const V4TrainingData& data) {
InputPlanes PlanesFromTrainingData(const V5TrainingData& data) {
InputPlanes result;
if (data.input_format !=
pblczero::NetworkFormat::InputFormat::INPUT_CLASSICAL_112_PLANE) {
throw Exception("FRC input variant not yet supported.");
}
for (int i = 0; i < 104; i++) {
result.emplace_back();
result.back().mask = ReverseBitsInBytes(data.planes[i]);
Expand Down Expand Up @@ -82,32 +86,60 @@ TrainingDataReader::TrainingDataReader(std::string filename)

TrainingDataReader::~TrainingDataReader() { gzclose(fin_); }

bool TrainingDataReader::ReadChunk(V4TrainingData* data) {
if (format_v4) {
bool TrainingDataReader::ReadChunk(V5TrainingData* data) {
if (format_v5) {
int read_size = gzread(fin_, reinterpret_cast<void*>(data), sizeof(*data));
if (read_size < 0) throw Exception("Corrupt read.");
return read_size == sizeof(*data);
} else {
size_t v5_extra = 16;
size_t v4_extra = 16;
size_t v3_size = sizeof(*data) - v4_extra;
size_t v3_size = sizeof(*data) - v4_extra - v5_extra;
int read_size = gzread(fin_, reinterpret_cast<void*>(data), v3_size);
if (read_size < 0) throw Exception("Corrupt read.");
if (read_size != v3_size) return false;
if (data->version == 3) {
data->version = 4;
data->root_q = 0.0f;
data->best_q = 0.0f;
data->root_d = 0.0f;
data->best_d = 0.0f;
return true;
} else {
format_v4 = true;
read_size = gzread(
fin_,
reinterpret_cast<void*>(reinterpret_cast<char*>(data) + v3_size),
v4_extra);
if (read_size < 0) throw Exception("Corrupt read.");
return read_size == v4_extra;
auto orig_version = data->version;
switch (data->version) {
case 3: {
data->version = 4;
// First convert 3 to 4 to reduce code duplication.
char* v4_extra_start = reinterpret_cast<char*>(data) + v3_size;
// Write 0 bytes for 16 extra bytes - corresponding to 4 floats of 0.0f.
for (int i = 0; i < v4_extra; i++) {
v4_extra_start[i] = 0;
}
// Deliberate fallthrough.
}
case 4: {
// If actually 4, we need to read the additional data first.
if (orig_version == 4) {
read_size = gzread(
fin_,
reinterpret_cast<void*>(reinterpret_cast<char*>(data) + v3_size),
v4_extra);
if (read_size < 0) throw Exception("Corrupt read.");
if (read_size != v4_extra) return false;
}
data->version = 5;
char* data_ptr = reinterpret_cast<char*>(data);
// Shift data after version back 4 bytes.
memmove(data_ptr + 2 * sizeof(uint32_t), data_ptr + sizeof(uint32_t),
v3_size + v4_extra - sizeof(uint32_t));
data->input_format = pblczero::NetworkFormat::INPUT_CLASSICAL_112_PLANE;
data->root_m = 0.0f;
data->best_m = 0.0f;
data->plies_left = 0.0f;
return true;
}
case 5: {
format_v5 = true;
read_size = gzread(
fin_,
reinterpret_cast<void*>(reinterpret_cast<char*>(data) + v3_size),
v4_extra + v5_extra);
if (read_size < 0) throw Exception("Corrupt read.");
return read_size == v4_extra + v5_extra;
}
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions src/neural/writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ struct V5TrainingData {
} PACKED_STRUCT;
static_assert(sizeof(V5TrainingData) == 8308, "Wrong struct size");

InputPlanes PlanesFromTrainingData(const V4TrainingData& data);
InputPlanes PlanesFromTrainingData(const V5TrainingData& data);

#pragma pack(pop)

Expand Down Expand Up @@ -96,15 +96,15 @@ class TrainingDataReader {
~TrainingDataReader();

// Reads a chunk. Returns true if a chunk was read.
bool ReadChunk(V4TrainingData* data);
bool ReadChunk(V5TrainingData* data);

// Gets full filename of the file being read.
std::string GetFileName() const { return filename_; }

private:
std::string filename_;
gzFile fin_;
bool format_v4 = false;
bool format_v5 = false;
};

} // namespace lczero
Loading

0 comments on commit 3fe4716

Please sign in to comment.