From 1690380b615bd36cd21bbbcd6306450fdef80bd2 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Thu, 6 May 2021 15:13:06 -0700 Subject: [PATCH 1/5] Hotfix for serializer --- include/treelite/tree_impl.h | 1 + 1 file changed, 1 insertion(+) diff --git a/include/treelite/tree_impl.h b/include/treelite/tree_impl.h index 5868fd02..90001c9e 100644 --- a/include/treelite/tree_impl.h +++ b/include/treelite/tree_impl.h @@ -925,6 +925,7 @@ ModelImpl::InitFromPyBuffer( auto tree_hanlder = [&begin](Tree& tree) { tree.InitFromPyBuffer(begin, begin + kNumFramePerTree); + begin += kNumFramePerTree; }; DeserializeTemplate(num_tree, header_field_handler, tree_hanlder); From b9b8cdd53fc94fa72fe372395a14ec9de2e7195e Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Thu, 6 May 2021 15:32:14 -0700 Subject: [PATCH 2/5] Add test cases to demonstrate the serializer bug --- tests/cpp/test_serializer.cc | 52 +++++++++++++++++++----------------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/tests/cpp/test_serializer.cc b/tests/cpp/test_serializer.cc index dfb156c0..4d4ede05 100644 --- a/tests/cpp/test_serializer.cc +++ b/tests/cpp/test_serializer.cc @@ -29,7 +29,7 @@ inline void TestRoundTrip(treelite::Model* model) { auto buffer = model->GetPyBuffer(); std::unique_ptr received_model = treelite::Model::CreateFromPyBuffer(buffer); - ASSERT_EQ(TreeliteToBytes(model), TreeliteToBytes(received_model.get())); + ASSERT_TRUE(TreeliteToBytes(model) == TreeliteToBytes(received_model.get())); } for (int i = 0; i < 2; ++i) { @@ -44,7 +44,7 @@ inline void TestRoundTrip(treelite::Model* model) { std::unique_ptr received_model = treelite::Model::DeserializeFromFile(fp); std::fclose(fp); - ASSERT_EQ(TreeliteToBytes(model), TreeliteToBytes(received_model.get())); + ASSERT_TRUE(TreeliteToBytes(model) == TreeliteToBytes(received_model.get())); } } @@ -178,7 +178,7 @@ void PyBufferInterfaceRoundTrip_TreeDepth2() { }; builder->SetModelParam("pred_transform", "sigmoid"); 
builder->SetModelParam("global_bias", "0.5"); - for (int tree_id = 0; tree_id < 2; ++tree_id) { + for (int tree_id = 0; tree_id < 3; ++tree_id) { std::unique_ptr tree{ new frontend::TreeBuilder(threshold_type, leaf_output_type) }; @@ -189,10 +189,10 @@ void PyBufferInterfaceRoundTrip_TreeDepth2() { tree->SetCategoricalTestNode(1, 0, {0, 1}, true, 3, 4); tree->SetCategoricalTestNode(2, 1, {0}, true, 5, 6); tree->SetRootNode(0); - tree->SetLeafNode(3, frontend::Value::Create(3)); - tree->SetLeafNode(4, frontend::Value::Create(1)); - tree->SetLeafNode(5, frontend::Value::Create(4)); - tree->SetLeafNode(6, frontend::Value::Create(2)); + tree->SetLeafNode(3, frontend::Value::Create(tree_id + 3)); + tree->SetLeafNode(4, frontend::Value::Create(tree_id + 1)); + tree->SetLeafNode(5, frontend::Value::Create(tree_id + 4)); + tree->SetLeafNode(6, frontend::Value::Create(tree_id + 2)); builder->InsertTree(tree.get()); } @@ -221,28 +221,30 @@ void PyBufferInterfaceRoundTrip_DeepFullTree() { std::unique_ptr builder{ new frontend::ModelBuilder(3, 1, false, threshold_type, leaf_output_type) }; - std::unique_ptr tree{ - new frontend::TreeBuilder(threshold_type, leaf_output_type) - }; - for (int level = 0; level <= depth; ++level) { - for (int i = 0; i < (1 << level); ++i) { - const int nid = (1 << level) - 1 + i; - tree->CreateNode(nid); + for (int tree_id = 0; tree_id < 3; ++tree_id) { + std::unique_ptr tree{ + new frontend::TreeBuilder(threshold_type, leaf_output_type) + }; + for (int level = 0; level <= depth; ++level) { + for (int i = 0; i < (1 << level); ++i) { + const int nid = (1 << level) - 1 + i; + tree->CreateNode(nid); + } } - } - for (int level = 0; level <= depth; ++level) { - for (int i = 0; i < (1 << level); ++i) { - const int nid = (1 << level) - 1 + i; - if (level == depth) { - tree->SetLeafNode(nid, frontend::Value::Create(1)); - } else { - tree->SetNumericalTestNode(nid, (level % 2), "<", frontend::Value::Create(0), - true, 2 * nid + 1, 2 * nid + 2); + for (int 
level = 0; level <= depth; ++level) { + for (int i = 0; i < (1 << level); ++i) { + const int nid = (1 << level) - 1 + i; + if (level == depth) { + tree->SetLeafNode(nid, frontend::Value::Create(tree_id + 1)); + } else { + tree->SetNumericalTestNode(nid, (level % 2), "<", frontend::Value::Create(0), + true, 2 * nid + 1, 2 * nid + 2); + } } } + tree->SetRootNode(0); + builder->InsertTree(tree.get()); } - tree->SetRootNode(0); - builder->InsertTree(tree.get()); std::unique_ptr model = builder->CommitModel(); TestRoundTrip(model.get()); From 24e72d4c55907334381fc7e1d4175f0023ef3b91 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Thu, 6 May 2021 22:15:44 -0700 Subject: [PATCH 3/5] Reduce amount of time for PyBufferInterfaceRoundTrip.DeepFullTree --- tests/cpp/test_serializer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/test_serializer.cc b/tests/cpp/test_serializer.cc index 4d4ede05..bfe35000 100644 --- a/tests/cpp/test_serializer.cc +++ b/tests/cpp/test_serializer.cc @@ -216,7 +216,7 @@ template void PyBufferInterfaceRoundTrip_DeepFullTree() { TypeInfo threshold_type = TypeToInfo(); TypeInfo leaf_output_type = TypeToInfo(); - const int depth = 19; + const int depth = 17; std::unique_ptr builder{ new frontend::ModelBuilder(3, 1, false, threshold_type, leaf_output_type) From 1056b88eb4ba6b3ca5119273b69c2f7c3ba7f947 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Thu, 6 May 2021 22:16:02 -0700 Subject: [PATCH 4/5] Add case PyBufferInterfaceRoundTrip.XGBoostBoston --- tests/cpp/test_serializer.cc | 180 +++++++++++++++++++++++++++++++++++ 1 file changed, 180 insertions(+) diff --git a/tests/cpp/test_serializer.cc b/tests/cpp/test_serializer.cc index bfe35000..73742e04 100644 --- a/tests/cpp/test_serializer.cc +++ b/tests/cpp/test_serializer.cc @@ -257,4 +257,184 @@ TEST(PyBufferInterfaceRoundTrip, DeepFullTree) { PyBufferInterfaceRoundTrip_DeepFullTree(); } +TEST(PyBufferInterfaceRoundTrip, XGBoostBoston) { + std::unique_ptr builder{ + new 
frontend::ModelBuilder(13, 1, false, TypeInfo::kFloat32, TypeInfo::kFloat32) + }; + using frontend::Value; + { + std::unique_ptr tree{ + new frontend::TreeBuilder(TypeInfo::kFloat32, TypeInfo::kFloat32) + }; + for (int nid = 0; nid <= 18; ++nid) { + tree->CreateNode(nid); + } + tree->SetRootNode(0); + tree->SetNumericalTestNode(0, 5, "<", Value::Create(6.67599964f), true, 1, 2); + tree->SetNumericalTestNode(1, 12, "<", Value::Create(16.0849991f), true, 3, 4); + tree->SetNumericalTestNode(3, 12, "<", Value::Create(9.71500015f), true, 7, 8); + tree->SetNumericalTestNode(7, 0, "<", Value::Create(4.72704506f), true, 15, 16); + tree->SetLeafNode(15, Value::Create(23.1568813f)); + tree->SetLeafNode(16, Value::Create(37.125f)); + tree->SetLeafNode(8, Value::Create(19.625f)); + tree->SetNumericalTestNode(4, 0, "<", Value::Create(7.31708527f), true, 9, 10); + tree->SetLeafNode(9, Value::Create(16.0535717f)); + tree->SetNumericalTestNode(10, 0, "<", Value::Create(24.9239006f), true, 17, 18); + tree->SetLeafNode(17, Value::Create(9.96249962f)); + tree->SetLeafNode(18, Value::Create(3.20000005f)); + tree->SetNumericalTestNode(2, 5, "<", Value::Create(7.43700027f), true, 5, 6); + tree->SetNumericalTestNode(5, 0, "<", Value::Create(8.18112564f), true, 11, 12); + tree->SetLeafNode(11, Value::Create(31.0228577f)); + tree->SetLeafNode(12, Value::Create(9.30000019f)); + tree->SetNumericalTestNode(6, 0, "<", Value::Create(2.74223518f), true, 13, 14); + tree->SetLeafNode(13, Value::Create(43.8833313f)); + tree->SetLeafNode(14, Value::Create(10.6999998f)); + builder->InsertTree(tree.get()); + } + { + std::unique_ptr tree{ + new frontend::TreeBuilder(TypeInfo::kFloat32, TypeInfo::kFloat32) + }; + for (int nid = 0; nid <= 18; ++nid) { + tree->CreateNode(nid); + } + tree->SetRootNode(0); + tree->SetNumericalTestNode(0, 7, "<", Value::Create(1.37205005f), true, 1, 2); + tree->SetNumericalTestNode(1, 12, "<", Value::Create(11.4849997f), true, 3, 4); + tree->SetLeafNode(3, 
Value::Create(12.3465471f)); + tree->SetNumericalTestNode(4, 0, "<", Value::Create(6.57226992f), true, 7, 8); + tree->SetLeafNode(7, Value::Create(-2.63571453f)); + tree->SetLeafNode(8, Value::Create(3.13541675f)); + tree->SetNumericalTestNode(2, 12, "<", Value::Create(5.43999958f), true, 5, 6); + tree->SetNumericalTestNode(5, 7, "<", Value::Create(2.85050011f), true, 9, 10); + tree->SetLeafNode(9, Value::Create(6.06147718f)); + tree->SetNumericalTestNode(10, 6, "<", Value::Create(11.9499998f), true, 13, 14); + tree->SetLeafNode(13, Value::Create(-1.27356553f)); + tree->SetLeafNode(14, Value::Create(1.90286708f)); + tree->SetNumericalTestNode(6, 9, "<", Value::Create(278.0f), true, 11, 12); + tree->SetNumericalTestNode(11, 7, "<", Value::Create(4.36069965f), true, 15, 16); + tree->SetLeafNode(15, Value::Create(2.3283093f)); + tree->SetLeafNode(16, Value::Create(-0.740369797f)); + tree->SetNumericalTestNode(12, 5, "<", Value::Create(6.67599964f), true, 17, 18); + tree->SetLeafNode(17, Value::Create(-0.374256492f)); + tree->SetLeafNode(18, Value::Create(-3.15714335f)); + builder->InsertTree(tree.get()); + } + { + std::unique_ptr tree{ + new frontend::TreeBuilder(TypeInfo::kFloat32, TypeInfo::kFloat32) + }; + for (int nid = 0; nid <= 30; ++nid) { + tree->CreateNode(nid); + } + tree->SetRootNode(0); + tree->SetNumericalTestNode(0, 10, "<", Value::Create(20.9500008f), true, 1, 2); + tree->SetNumericalTestNode(1, 7, "<", Value::Create(6.22315025f), true, 3, 4); + tree->SetNumericalTestNode(3, 8, "<", Value::Create(2.5f), true, 7, 8); + tree->SetNumericalTestNode(7, 6, "<", Value::Create(62.7999992f), true, 15, 16); + tree->SetLeafNode(15, Value::Create(1.08206058f)); + tree->SetLeafNode(16, Value::Create(-2.5734961f)); + tree->SetNumericalTestNode(8, 2, "<", Value::Create(3.09500003f), true, 17, 18); + tree->SetLeafNode(17, Value::Create(2.1601212f)); + tree->SetLeafNode(18, Value::Create(0.435377121f)); + tree->SetNumericalTestNode(4, 5, "<", Value::Create(6.94500017f), 
true, 9, 10); + tree->SetNumericalTestNode(9, 5, "<", Value::Create(6.67650032f), true, 19, 20); + tree->SetLeafNode(19, Value::Create(-0.727042675f)); + tree->SetLeafNode(20, Value::Create(-2.75314689f)); + tree->SetNumericalTestNode(10, 12, "<", Value::Create(4.77499962f), true, 21, 22); + tree->SetLeafNode(21, Value::Create(-0.468363762f)); + tree->SetLeafNode(22, Value::Create(3.01290059f)); + tree->SetNumericalTestNode(2, 5, "<", Value::Create(6.11900043f), true, 5, 6); + tree->SetNumericalTestNode(5, 9, "<", Value::Create(385.5f), true, 11, 12); + tree->SetNumericalTestNode(11, 5, "<", Value::Create(5.51300001f), true, 23, 24); + tree->SetLeafNode(23, Value::Create(0.224628448f)); + tree->SetLeafNode(24, Value::Create(-2.32530594f)); + tree->SetNumericalTestNode(12, 5, "<", Value::Create(5.91499996f), true, 25, 26); + tree->SetLeafNode(25, Value::Create(-1.16225815f)); + tree->SetLeafNode(26, Value::Create(0.610342026f)); + tree->SetNumericalTestNode(6, 0, "<", Value::Create(0.448424995f), true, 13, 14); + tree->SetNumericalTestNode(13, 5, "<", Value::Create(6.45600033f), true, 27, 28); + tree->SetLeafNode(27, Value::Create(-1.64520073f)); + tree->SetLeafNode(28, Value::Create(-0.275371552f)); + tree->SetNumericalTestNode(14, 0, "<", Value::Create(0.681519985f), true, 29, 30); + tree->SetLeafNode(29, Value::Create(1.69765615f)); + tree->SetLeafNode(30, Value::Create(-0.246309474f)); + builder->InsertTree(tree.get()); + } + { + std::unique_ptr tree{ + new frontend::TreeBuilder(TypeInfo::kFloat32, TypeInfo::kFloat32) + }; + for (int nid = 0; nid <= 26; ++nid) { + tree->CreateNode(nid); + } + tree->SetRootNode(0); + tree->SetNumericalTestNode(0, 5, "<", Value::Create(6.68949986f), true, 1, 2); + tree->SetNumericalTestNode(1, 5, "<", Value::Create(6.5454998f), true, 3, 4); + tree->SetNumericalTestNode(3, 9, "<", Value::Create(207.5f), true, 7, 8); + tree->SetNumericalTestNode(7, 11, "<", Value::Create(377.880005f), true, 15, 16); + tree->SetLeafNode(15, 
Value::Create(0.200853109f)); + tree->SetLeafNode(16, Value::Create(3.14392781f)); + tree->SetNumericalTestNode(8, 0, "<", Value::Create(0.085769996f), true, 17, 18); + tree->SetLeafNode(17, Value::Create(-0.822109044f)); + tree->SetLeafNode(18, Value::Create(0.266653359f)); + tree->SetNumericalTestNode(4, 5, "<", Value::Create(6.56400013f), true, 9, 10); + tree->SetLeafNode(9, Value::Create(3.40855145f)); + tree->SetNumericalTestNode(10, 7, "<", Value::Create(3.29480004f), true, 19, 20); + tree->SetLeafNode(19, Value::Create(2.98598123f)); + tree->SetLeafNode(20, Value::Create(0.94572562f)); + tree->SetNumericalTestNode(2, 7, "<", Value::Create(1.95050001f), true, 5, 6); + tree->SetNumericalTestNode(5, 5, "<", Value::Create(6.86849976f), true, 11, 12); + tree->SetLeafNode(11, Value::Create(-0.0353970528f)); + tree->SetNumericalTestNode(12, 0, "<", Value::Create(0.943544984f), true, 21, 22); + tree->SetLeafNode(21, Value::Create(0.761680603f)); + tree->SetLeafNode(22, Value::Create(3.02160382f)); + tree->SetNumericalTestNode(6, 6, "<", Value::Create(50.75f), true, 13, 14); + tree->SetNumericalTestNode(13, 10, "<", Value::Create(15.5500002f), true, 23, 24); + tree->SetLeafNode(23, Value::Create(0.743751168f)); + tree->SetLeafNode(24, Value::Create(-0.792990744f)); + tree->SetNumericalTestNode(14, 11, "<", Value::Create(384.794983f), true, 25, 26); + tree->SetLeafNode(25, Value::Create(0.319963276f)); + tree->SetLeafNode(26, Value::Create(-2.88059473f)); + builder->InsertTree(tree.get()); + } + { + std::unique_ptr tree{ + new frontend::TreeBuilder(TypeInfo::kFloat32, TypeInfo::kFloat32) + }; + for (int nid = 0; nid <= 24; ++nid) { + tree->CreateNode(nid); + } + tree->SetRootNode(0); + tree->SetNumericalTestNode(0, 4, "<", Value::Create(0.820500016f), true, 1, 2); + tree->SetNumericalTestNode(1, 10, "<", Value::Create(17.7000008f), true, 3, 4); + tree->SetNumericalTestNode(3, 5, "<", Value::Create(6.5255003f), true, 7, 8); + tree->SetNumericalTestNode(7, 0, "<", 
Value::Create(0.0687299967f), true, 15, 16); + tree->SetLeafNode(15, Value::Create(0.206869483f)); + tree->SetLeafNode(16, Value::Create(1.80078018f)); + tree->SetNumericalTestNode(8, 12, "<", Value::Create(3.14499998f), true, 17, 18); + tree->SetLeafNode(17, Value::Create(-0.923567116f)); + tree->SetLeafNode(18, Value::Create(0.386075258f)); + tree->SetNumericalTestNode(4, 0, "<", Value::Create(0.0301299989f), true, 9, 10); + tree->SetNumericalTestNode(9, 8, "<", Value::Create(3.5f), true, 19, 20); + tree->SetLeafNode(19, Value::Create(1.81692481f)); + tree->SetLeafNode(20, Value::Create(0.130609035f)); + tree->SetNumericalTestNode(10, 4, "<", Value::Create(0.708999991f), true, 21, 22); + tree->SetLeafNode(21, Value::Create(-0.330942363f)); + tree->SetLeafNode(22, Value::Create(1.32937813f)); + tree->SetNumericalTestNode(2, 11, "<", Value::Create(347.565002f), true, 5, 6); + tree->SetNumericalTestNode(5, 6, "<", Value::Create(97.25f), true, 11, 12); + tree->SetLeafNode(11, Value::Create(-3.44793344f)); + tree->SetLeafNode(12, Value::Create(-1.2536478f)); + tree->SetNumericalTestNode(6, 0, "<", Value::Create(2.34980488f), true, 13, 14); + tree->SetNumericalTestNode(13, 0, "<", Value::Create(1.54080999f), true, 23, 24); + tree->SetLeafNode(23, Value::Create(-0.342328072f)); + tree->SetLeafNode(24, Value::Create(0.655293167f)); + tree->SetLeafNode(14, Value::Create(-1.51396859f)); + builder->InsertTree(tree.get()); + } + std::unique_ptr model = builder->CommitModel(); + TestRoundTrip(model.get()); +} + } // namespace treelite From 561cb0e8a787c72d16f57e4c0458ba9a7a043038 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Mon, 10 May 2021 17:52:23 -0700 Subject: [PATCH 5/5] Address reviewer's feedback --- include/treelite/tree_impl.h | 16 +++++++++------- tests/cpp/test_serializer.cc | 4 ++++ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/include/treelite/tree_impl.h b/include/treelite/tree_impl.h index 90001c9e..fcfa586c 100644 --- 
a/include/treelite/tree_impl.h +++ b/include/treelite/tree_impl.h @@ -859,7 +859,7 @@ template template inline void ModelImpl::DeserializeTemplate( - size_t num_tree, + std::size_t num_tree, HeaderFieldHandlerFunc header_field_handler, TreeHandlerFunc tree_handler) { /* Header */ @@ -870,7 +870,7 @@ ModelImpl::DeserializeTemplate( header_field_handler(&param); /* Body */ trees.clear(); - for (size_t i = 0; i < num_tree; ++i) { + for (std::size_t i = 0; i < num_tree; ++i) { trees.emplace_back(); tree_handler(trees.back()); } @@ -917,18 +917,20 @@ ModelImpl::InitFromPyBuffer( if (num_frame < kNumFrameInHeader || (num_frame - kNumFrameInHeader) % kNumFramePerTree != 0) { throw std::runtime_error("Wrong number of frames"); } - const size_t num_tree = (num_frame - kNumFrameInHeader) / kNumFramePerTree; + const std::size_t num_tree = (num_frame - kNumFrameInHeader) / kNumFramePerTree; auto header_field_handler = [&begin](auto* field) { InitScalarFromPyBuffer(field, *begin++); }; - auto tree_hanlder = [&begin](Tree& tree) { + auto tree_handler = [&begin](Tree& tree) { + // Read the frames in the range [begin, begin + kNumFramePerTree) into the tree tree.InitFromPyBuffer(begin, begin + kNumFramePerTree); begin += kNumFramePerTree; + // Advance the iterator so that the next tree reads the next kNumFramePerTree frames }; - DeserializeTemplate(num_tree, header_field_handler, tree_hanlder); + DeserializeTemplate(num_tree, header_field_handler, tree_handler); } template @@ -941,11 +943,11 @@ ModelImpl::DeserializeFromFileImpl(FILE* src_fp) ReadScalarFromFile(field, src_fp); }; - auto tree_hanlder = [src_fp](Tree& tree) { + auto tree_handler = [src_fp](Tree& tree) { tree.DeserializeFromFile(src_fp); }; - DeserializeTemplate(num_tree, header_field_handler, tree_hanlder); + DeserializeTemplate(num_tree, header_field_handler, tree_handler); } inline void InitParamAndCheck(ModelParam* param, diff --git a/tests/cpp/test_serializer.cc b/tests/cpp/test_serializer.cc index 
73742e04..257c8ce2 100644 --- a/tests/cpp/test_serializer.cc +++ b/tests/cpp/test_serializer.cc @@ -29,6 +29,8 @@ inline void TestRoundTrip(treelite::Model* model) { auto buffer = model->GetPyBuffer(); std::unique_ptr received_model = treelite::Model::CreateFromPyBuffer(buffer); + // Use ASSERT_TRUE, since ASSERT_EQ will dump all the raw bytes into a string, potentially + // causing an OOM error ASSERT_TRUE(TreeliteToBytes(model) == TreeliteToBytes(received_model.get())); } @@ -44,6 +46,8 @@ inline void TestRoundTrip(treelite::Model* model) { std::unique_ptr received_model = treelite::Model::DeserializeFromFile(fp); std::fclose(fp); + // Use ASSERT_TRUE, since ASSERT_EQ will dump all the raw bytes into a string, potentially + // causing an OOM error ASSERT_TRUE(TreeliteToBytes(model) == TreeliteToBytes(received_model.get())); } }