-
Notifications
You must be signed in to change notification settings - Fork 100
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Hotfix for serializer #273
Changes from all commits
1690380
b9b8cdd
24e72d4
1056b88
561cb0e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,7 +29,9 @@ inline void TestRoundTrip(treelite::Model* model) { | |
auto buffer = model->GetPyBuffer(); | ||
std::unique_ptr<treelite::Model> received_model = treelite::Model::CreateFromPyBuffer(buffer); | ||
|
||
ASSERT_EQ(TreeliteToBytes(model), TreeliteToBytes(received_model.get())); | ||
// Use ASSERT_TRUE, since ASSERT_EQ will dump all the raw bytes into a string, potentially | ||
// causing an OOM error | ||
ASSERT_TRUE(TreeliteToBytes(model) == TreeliteToBytes(received_model.get())); | ||
} | ||
|
||
for (int i = 0; i < 2; ++i) { | ||
|
@@ -44,7 +46,9 @@ inline void TestRoundTrip(treelite::Model* model) { | |
std::unique_ptr<treelite::Model> received_model = treelite::Model::DeserializeFromFile(fp); | ||
std::fclose(fp); | ||
|
||
ASSERT_EQ(TreeliteToBytes(model), TreeliteToBytes(received_model.get())); | ||
// Use ASSERT_TRUE, since ASSERT_EQ will dump all the raw bytes into a string, potentially | ||
// causing an OOM error | ||
ASSERT_TRUE(TreeliteToBytes(model) == TreeliteToBytes(received_model.get())); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe add a comment to this effect so that someone doesn't naively reintroduce this problem down the line? |
||
} | ||
} | ||
|
||
|
@@ -178,7 +182,7 @@ void PyBufferInterfaceRoundTrip_TreeDepth2() { | |
}; | ||
builder->SetModelParam("pred_transform", "sigmoid"); | ||
builder->SetModelParam("global_bias", "0.5"); | ||
for (int tree_id = 0; tree_id < 2; ++tree_id) { | ||
for (int tree_id = 0; tree_id < 3; ++tree_id) { | ||
std::unique_ptr<frontend::TreeBuilder> tree{ | ||
new frontend::TreeBuilder(threshold_type, leaf_output_type) | ||
}; | ||
|
@@ -189,10 +193,10 @@ void PyBufferInterfaceRoundTrip_TreeDepth2() { | |
tree->SetCategoricalTestNode(1, 0, {0, 1}, true, 3, 4); | ||
tree->SetCategoricalTestNode(2, 1, {0}, true, 5, 6); | ||
tree->SetRootNode(0); | ||
tree->SetLeafNode(3, frontend::Value::Create<LeafOutputType>(3)); | ||
tree->SetLeafNode(4, frontend::Value::Create<LeafOutputType>(1)); | ||
tree->SetLeafNode(5, frontend::Value::Create<LeafOutputType>(4)); | ||
tree->SetLeafNode(6, frontend::Value::Create<LeafOutputType>(2)); | ||
tree->SetLeafNode(3, frontend::Value::Create<LeafOutputType>(tree_id + 3)); | ||
tree->SetLeafNode(4, frontend::Value::Create<LeafOutputType>(tree_id + 1)); | ||
tree->SetLeafNode(5, frontend::Value::Create<LeafOutputType>(tree_id + 4)); | ||
tree->SetLeafNode(6, frontend::Value::Create<LeafOutputType>(tree_id + 2)); | ||
Comment on lines
+196
to
+199
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I crafted the test case so that all the trees will have different leaf outputs. The bug being addressed causes all trees to be identical. |
||
builder->InsertTree(tree.get()); | ||
} | ||
|
||
|
@@ -216,33 +220,35 @@ template <typename ThresholdType, typename LeafOutputType> | |
void PyBufferInterfaceRoundTrip_DeepFullTree() { | ||
TypeInfo threshold_type = TypeToInfo<ThresholdType>(); | ||
TypeInfo leaf_output_type = TypeToInfo<LeafOutputType>(); | ||
const int depth = 19; | ||
const int depth = 17; | ||
|
||
std::unique_ptr<frontend::ModelBuilder> builder{ | ||
new frontend::ModelBuilder(3, 1, false, threshold_type, leaf_output_type) | ||
}; | ||
std::unique_ptr<frontend::TreeBuilder> tree{ | ||
new frontend::TreeBuilder(threshold_type, leaf_output_type) | ||
}; | ||
for (int level = 0; level <= depth; ++level) { | ||
for (int i = 0; i < (1 << level); ++i) { | ||
const int nid = (1 << level) - 1 + i; | ||
tree->CreateNode(nid); | ||
for (int tree_id = 0; tree_id < 3; ++tree_id) { | ||
std::unique_ptr<frontend::TreeBuilder> tree{ | ||
new frontend::TreeBuilder(threshold_type, leaf_output_type) | ||
}; | ||
for (int level = 0; level <= depth; ++level) { | ||
for (int i = 0; i < (1 << level); ++i) { | ||
const int nid = (1 << level) - 1 + i; | ||
tree->CreateNode(nid); | ||
} | ||
} | ||
} | ||
for (int level = 0; level <= depth; ++level) { | ||
for (int i = 0; i < (1 << level); ++i) { | ||
const int nid = (1 << level) - 1 + i; | ||
if (level == depth) { | ||
tree->SetLeafNode(nid, frontend::Value::Create<LeafOutputType>(1)); | ||
} else { | ||
tree->SetNumericalTestNode(nid, (level % 2), "<", frontend::Value::Create<ThresholdType>(0), | ||
true, 2 * nid + 1, 2 * nid + 2); | ||
for (int level = 0; level <= depth; ++level) { | ||
for (int i = 0; i < (1 << level); ++i) { | ||
const int nid = (1 << level) - 1 + i; | ||
if (level == depth) { | ||
tree->SetLeafNode(nid, frontend::Value::Create<LeafOutputType>(tree_id + 1)); | ||
} else { | ||
tree->SetNumericalTestNode(nid, (level % 2), "<", frontend::Value::Create<ThresholdType>(0), | ||
true, 2 * nid + 1, 2 * nid + 2); | ||
} | ||
} | ||
} | ||
tree->SetRootNode(0); | ||
builder->InsertTree(tree.get()); | ||
} | ||
tree->SetRootNode(0); | ||
builder->InsertTree(tree.get()); | ||
|
||
std::unique_ptr<Model> model = builder->CommitModel(); | ||
TestRoundTrip(model.get()); | ||
|
@@ -255,4 +261,184 @@ TEST(PyBufferInterfaceRoundTrip, DeepFullTree) { | |
PyBufferInterfaceRoundTrip_DeepFullTree<double, uint32_t>(); | ||
} | ||
|
||
TEST(PyBufferInterfaceRoundTrip, XGBoostBoston) { | ||
std::unique_ptr<frontend::ModelBuilder> builder{ | ||
new frontend::ModelBuilder(13, 1, false, TypeInfo::kFloat32, TypeInfo::kFloat32) | ||
}; | ||
using frontend::Value; | ||
{ | ||
std::unique_ptr<frontend::TreeBuilder> tree{ | ||
new frontend::TreeBuilder(TypeInfo::kFloat32, TypeInfo::kFloat32) | ||
}; | ||
for (int nid = 0; nid <= 18; ++nid) { | ||
tree->CreateNode(nid); | ||
} | ||
tree->SetRootNode(0); | ||
tree->SetNumericalTestNode(0, 5, "<", Value::Create<float>(6.67599964f), true, 1, 2); | ||
tree->SetNumericalTestNode(1, 12, "<", Value::Create<float>(16.0849991f), true, 3, 4); | ||
tree->SetNumericalTestNode(3, 12, "<", Value::Create<float>(9.71500015f), true, 7, 8); | ||
tree->SetNumericalTestNode(7, 0, "<", Value::Create<float>(4.72704506f), true, 15, 16); | ||
tree->SetLeafNode(15, Value::Create<float>(23.1568813f)); | ||
tree->SetLeafNode(16, Value::Create<float>(37.125f)); | ||
tree->SetLeafNode(8, Value::Create<float>(19.625f)); | ||
tree->SetNumericalTestNode(4, 0, "<", Value::Create<float>(7.31708527f), true, 9, 10); | ||
tree->SetLeafNode(9, Value::Create<float>(16.0535717f)); | ||
tree->SetNumericalTestNode(10, 0, "<", Value::Create<float>(24.9239006f), true, 17, 18); | ||
tree->SetLeafNode(17, Value::Create<float>(9.96249962f)); | ||
tree->SetLeafNode(18, Value::Create<float>(3.20000005f)); | ||
tree->SetNumericalTestNode(2, 5, "<", Value::Create<float>(7.43700027f), true, 5, 6); | ||
tree->SetNumericalTestNode(5, 0, "<", Value::Create<float>(8.18112564f), true, 11, 12); | ||
tree->SetLeafNode(11, Value::Create<float>(31.0228577f)); | ||
tree->SetLeafNode(12, Value::Create<float>(9.30000019f)); | ||
tree->SetNumericalTestNode(6, 0, "<", Value::Create<float>(2.74223518f), true, 13, 14); | ||
tree->SetLeafNode(13, Value::Create<float>(43.8833313f)); | ||
tree->SetLeafNode(14, Value::Create<float>(10.6999998f)); | ||
builder->InsertTree(tree.get()); | ||
} | ||
{ | ||
std::unique_ptr<frontend::TreeBuilder> tree{ | ||
new frontend::TreeBuilder(TypeInfo::kFloat32, TypeInfo::kFloat32) | ||
}; | ||
for (int nid = 0; nid <= 18; ++nid) { | ||
tree->CreateNode(nid); | ||
} | ||
tree->SetRootNode(0); | ||
tree->SetNumericalTestNode(0, 7, "<", Value::Create<float>(1.37205005f), true, 1, 2); | ||
tree->SetNumericalTestNode(1, 12, "<", Value::Create<float>(11.4849997f), true, 3, 4); | ||
tree->SetLeafNode(3, Value::Create<float>(12.3465471f)); | ||
tree->SetNumericalTestNode(4, 0, "<", Value::Create<float>(6.57226992f), true, 7, 8); | ||
tree->SetLeafNode(7, Value::Create<float>(-2.63571453f)); | ||
tree->SetLeafNode(8, Value::Create<float>(3.13541675f)); | ||
tree->SetNumericalTestNode(2, 12, "<", Value::Create<float>(5.43999958f), true, 5, 6); | ||
tree->SetNumericalTestNode(5, 7, "<", Value::Create<float>(2.85050011f), true, 9, 10); | ||
tree->SetLeafNode(9, Value::Create<float>(6.06147718f)); | ||
tree->SetNumericalTestNode(10, 6, "<", Value::Create<float>(11.9499998f), true, 13, 14); | ||
tree->SetLeafNode(13, Value::Create<float>(-1.27356553f)); | ||
tree->SetLeafNode(14, Value::Create<float>(1.90286708f)); | ||
tree->SetNumericalTestNode(6, 9, "<", Value::Create<float>(278.0f), true, 11, 12); | ||
tree->SetNumericalTestNode(11, 7, "<", Value::Create<float>(4.36069965f), true, 15, 16); | ||
tree->SetLeafNode(15, Value::Create<float>(2.3283093f)); | ||
tree->SetLeafNode(16, Value::Create<float>(-0.740369797f)); | ||
tree->SetNumericalTestNode(12, 5, "<", Value::Create<float>(6.67599964f), true, 17, 18); | ||
tree->SetLeafNode(17, Value::Create<float>(-0.374256492f)); | ||
tree->SetLeafNode(18, Value::Create<float>(-3.15714335f)); | ||
builder->InsertTree(tree.get()); | ||
} | ||
{ | ||
std::unique_ptr<frontend::TreeBuilder> tree{ | ||
new frontend::TreeBuilder(TypeInfo::kFloat32, TypeInfo::kFloat32) | ||
}; | ||
for (int nid = 0; nid <= 30; ++nid) { | ||
tree->CreateNode(nid); | ||
} | ||
tree->SetRootNode(0); | ||
tree->SetNumericalTestNode(0, 10, "<", Value::Create<float>(20.9500008f), true, 1, 2); | ||
tree->SetNumericalTestNode(1, 7, "<", Value::Create<float>(6.22315025f), true, 3, 4); | ||
tree->SetNumericalTestNode(3, 8, "<", Value::Create<float>(2.5f), true, 7, 8); | ||
tree->SetNumericalTestNode(7, 6, "<", Value::Create<float>(62.7999992f), true, 15, 16); | ||
tree->SetLeafNode(15, Value::Create<float>(1.08206058f)); | ||
tree->SetLeafNode(16, Value::Create<float>(-2.5734961f)); | ||
tree->SetNumericalTestNode(8, 2, "<", Value::Create<float>(3.09500003f), true, 17, 18); | ||
tree->SetLeafNode(17, Value::Create<float>(2.1601212f)); | ||
tree->SetLeafNode(18, Value::Create<float>(0.435377121f)); | ||
tree->SetNumericalTestNode(4, 5, "<", Value::Create<float>(6.94500017f), true, 9, 10); | ||
tree->SetNumericalTestNode(9, 5, "<", Value::Create<float>(6.67650032f), true, 19, 20); | ||
tree->SetLeafNode(19, Value::Create<float>(-0.727042675f)); | ||
tree->SetLeafNode(20, Value::Create<float>(-2.75314689f)); | ||
tree->SetNumericalTestNode(10, 12, "<", Value::Create<float>(4.77499962f), true, 21, 22); | ||
tree->SetLeafNode(21, Value::Create<float>(-0.468363762f)); | ||
tree->SetLeafNode(22, Value::Create<float>(3.01290059f)); | ||
tree->SetNumericalTestNode(2, 5, "<", Value::Create<float>(6.11900043f), true, 5, 6); | ||
tree->SetNumericalTestNode(5, 9, "<", Value::Create<float>(385.5f), true, 11, 12); | ||
tree->SetNumericalTestNode(11, 5, "<", Value::Create<float>(5.51300001f), true, 23, 24); | ||
tree->SetLeafNode(23, Value::Create<float>(0.224628448f)); | ||
tree->SetLeafNode(24, Value::Create<float>(-2.32530594f)); | ||
tree->SetNumericalTestNode(12, 5, "<", Value::Create<float>(5.91499996f), true, 25, 26); | ||
tree->SetLeafNode(25, Value::Create<float>(-1.16225815f)); | ||
tree->SetLeafNode(26, Value::Create<float>(0.610342026f)); | ||
tree->SetNumericalTestNode(6, 0, "<", Value::Create<float>(0.448424995f), true, 13, 14); | ||
tree->SetNumericalTestNode(13, 5, "<", Value::Create<float>(6.45600033f), true, 27, 28); | ||
tree->SetLeafNode(27, Value::Create<float>(-1.64520073f)); | ||
tree->SetLeafNode(28, Value::Create<float>(-0.275371552f)); | ||
tree->SetNumericalTestNode(14, 0, "<", Value::Create<float>(0.681519985f), true, 29, 30); | ||
tree->SetLeafNode(29, Value::Create<float>(1.69765615f)); | ||
tree->SetLeafNode(30, Value::Create<float>(-0.246309474f)); | ||
builder->InsertTree(tree.get()); | ||
} | ||
{ | ||
std::unique_ptr<frontend::TreeBuilder> tree{ | ||
new frontend::TreeBuilder(TypeInfo::kFloat32, TypeInfo::kFloat32) | ||
}; | ||
for (int nid = 0; nid <= 26; ++nid) { | ||
tree->CreateNode(nid); | ||
} | ||
tree->SetRootNode(0); | ||
tree->SetNumericalTestNode(0, 5, "<", Value::Create<float>(6.68949986f), true, 1, 2); | ||
tree->SetNumericalTestNode(1, 5, "<", Value::Create<float>(6.5454998f), true, 3, 4); | ||
tree->SetNumericalTestNode(3, 9, "<", Value::Create<float>(207.5f), true, 7, 8); | ||
tree->SetNumericalTestNode(7, 11, "<", Value::Create<float>(377.880005f), true, 15, 16); | ||
tree->SetLeafNode(15, Value::Create<float>(0.200853109f)); | ||
tree->SetLeafNode(16, Value::Create<float>(3.14392781f)); | ||
tree->SetNumericalTestNode(8, 0, "<", Value::Create<float>(0.085769996f), true, 17, 18); | ||
tree->SetLeafNode(17, Value::Create<float>(-0.822109044f)); | ||
tree->SetLeafNode(18, Value::Create<float>(0.266653359f)); | ||
tree->SetNumericalTestNode(4, 5, "<", Value::Create<float>(6.56400013f), true, 9, 10); | ||
tree->SetLeafNode(9, Value::Create<float>(3.40855145f)); | ||
tree->SetNumericalTestNode(10, 7, "<", Value::Create<float>(3.29480004f), true, 19, 20); | ||
tree->SetLeafNode(19, Value::Create<float>(2.98598123f)); | ||
tree->SetLeafNode(20, Value::Create<float>(0.94572562f)); | ||
tree->SetNumericalTestNode(2, 7, "<", Value::Create<float>(1.95050001f), true, 5, 6); | ||
tree->SetNumericalTestNode(5, 5, "<", Value::Create<float>(6.86849976f), true, 11, 12); | ||
tree->SetLeafNode(11, Value::Create<float>(-0.0353970528f)); | ||
tree->SetNumericalTestNode(12, 0, "<", Value::Create<float>(0.943544984f), true, 21, 22); | ||
tree->SetLeafNode(21, Value::Create<float>(0.761680603f)); | ||
tree->SetLeafNode(22, Value::Create<float>(3.02160382f)); | ||
tree->SetNumericalTestNode(6, 6, "<", Value::Create<float>(50.75f), true, 13, 14); | ||
tree->SetNumericalTestNode(13, 10, "<", Value::Create<float>(15.5500002f), true, 23, 24); | ||
tree->SetLeafNode(23, Value::Create<float>(0.743751168f)); | ||
tree->SetLeafNode(24, Value::Create<float>(-0.792990744f)); | ||
tree->SetNumericalTestNode(14, 11, "<", Value::Create<float>(384.794983f), true, 25, 26); | ||
tree->SetLeafNode(25, Value::Create<float>(0.319963276f)); | ||
tree->SetLeafNode(26, Value::Create<float>(-2.88059473f)); | ||
builder->InsertTree(tree.get()); | ||
} | ||
{ | ||
std::unique_ptr<frontend::TreeBuilder> tree{ | ||
new frontend::TreeBuilder(TypeInfo::kFloat32, TypeInfo::kFloat32) | ||
}; | ||
for (int nid = 0; nid <= 24; ++nid) { | ||
tree->CreateNode(nid); | ||
} | ||
tree->SetRootNode(0); | ||
tree->SetNumericalTestNode(0, 4, "<", Value::Create<float>(0.820500016f), true, 1, 2); | ||
tree->SetNumericalTestNode(1, 10, "<", Value::Create<float>(17.7000008f), true, 3, 4); | ||
tree->SetNumericalTestNode(3, 5, "<", Value::Create<float>(6.5255003f), true, 7, 8); | ||
tree->SetNumericalTestNode(7, 0, "<", Value::Create<float>(0.0687299967f), true, 15, 16); | ||
tree->SetLeafNode(15, Value::Create<float>(0.206869483f)); | ||
tree->SetLeafNode(16, Value::Create<float>(1.80078018f)); | ||
tree->SetNumericalTestNode(8, 12, "<", Value::Create<float>(3.14499998f), true, 17, 18); | ||
tree->SetLeafNode(17, Value::Create<float>(-0.923567116f)); | ||
tree->SetLeafNode(18, Value::Create<float>(0.386075258f)); | ||
tree->SetNumericalTestNode(4, 0, "<", Value::Create<float>(0.0301299989f), true, 9, 10); | ||
tree->SetNumericalTestNode(9, 8, "<", Value::Create<float>(3.5f), true, 19, 20); | ||
tree->SetLeafNode(19, Value::Create<float>(1.81692481f)); | ||
tree->SetLeafNode(20, Value::Create<float>(0.130609035f)); | ||
tree->SetNumericalTestNode(10, 4, "<", Value::Create<float>(0.708999991f), true, 21, 22); | ||
tree->SetLeafNode(21, Value::Create<float>(-0.330942363f)); | ||
tree->SetLeafNode(22, Value::Create<float>(1.32937813f)); | ||
tree->SetNumericalTestNode(2, 11, "<", Value::Create<float>(347.565002f), true, 5, 6); | ||
tree->SetNumericalTestNode(5, 6, "<", Value::Create<float>(97.25f), true, 11, 12); | ||
tree->SetLeafNode(11, Value::Create<float>(-3.44793344f)); | ||
tree->SetLeafNode(12, Value::Create<float>(-1.2536478f)); | ||
tree->SetNumericalTestNode(6, 0, "<", Value::Create<float>(2.34980488f), true, 13, 14); | ||
tree->SetNumericalTestNode(13, 0, "<", Value::Create<float>(1.54080999f), true, 23, 24); | ||
tree->SetLeafNode(23, Value::Create<float>(-0.342328072f)); | ||
tree->SetLeafNode(24, Value::Create<float>(0.655293167f)); | ||
tree->SetLeafNode(14, Value::Create<float>(-1.51396859f)); | ||
builder->InsertTree(tree.get()); | ||
} | ||
std::unique_ptr<Model> model = builder->CommitModel(); | ||
TestRoundTrip(model.get()); | ||
} | ||
|
||
} // namespace treelite |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since we're capturing
begin
by reference, won't this iterate begin forward as a side effect of executing the lambda? It's not immediately clear to me why that is the correct thing to do with this handler. Maybe add a comment or refactor such that the lambda returns the next iterator to be processed rather than directly changingbegin
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, this is intended behavior.
That would not be feasible, since the tree handler for the file stream does not involve iterators.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great! Can we get a comment just to make it clear to future developers that the side-effect is intentional?