From 2c438e01c5a6235cbedebdbc7c27faabafa16872 Mon Sep 17 00:00:00 2001 From: gAldeia Date: Sun, 14 Apr 2024 20:27:00 -0300 Subject: [PATCH] Improving PTC2 to work with strong typed programs. Fixed infinite loop in PTC2 --- src/data/io.cpp | 7 ++--- src/search_space.cpp | 54 +++++++++++++++++++++----------------- tests/cpp/test_program.cpp | 14 ++++++++-- 3 files changed, 46 insertions(+), 29 deletions(-) diff --git a/src/data/io.cpp b/src/data/io.cpp index 8293f478..d81559ae 100755 --- a/src/data/io.cpp +++ b/src/data/io.cpp @@ -81,9 +81,10 @@ Dataset read_csv ( // check if endpoint is binary bool binary_endpoint = (y.array() == 0 || y.array() == 1).all(); - auto result = Dataset(features,y,binary_endpoint); - return result; - + // using constructor 1. (initializing data from a map) + auto result = Dataset(features, y, binary_endpoint); + + return result; } } // Brush diff --git a/src/search_space.cpp b/src/search_space.cpp index 5fc8621e..6aaf45cb 100644 --- a/src/search_space.cpp +++ b/src/search_space.cpp @@ -267,14 +267,14 @@ tree SearchSpace::PTC2(Node root, int max_d, int max_size) const auto Tree = tree(); - /* fmt::print("building program with max size {}, max depth {}",max_size,max_d); */ + fmt::print("building program with max size {}, max depth {}",max_size,max_d); // Queue of nodes that need children vector> queue; - /* cout << "chose " << n.name << endl; */ + cout << "root " << root.name << endl; // auto spot = Tree.set_head(n); - /* cout << "inserting...\n"; */ + cout << "inserting...\n"; auto spot = Tree.insert(Tree.begin(), root); // node depth @@ -295,7 +295,7 @@ tree SearchSpace::PTC2(Node root, int max_d, int max_size) const //For each argument position a of n, Enqueue(a; g) for (auto a : root.arg_types) { - /* cout << "queing a node of type " << DataTypeName[a] << endl; */ + cout << "queing a node of type " << DataTypeName[a] << endl; auto child_spot = Tree.append_child(spot); queue.push_back(make_tuple(child_spot, a, d)); } @@ -304,8 +304,8 @@ tree SearchSpace::PTC2(Node root, int max_d, int max_size) const Node n; // Now we actually start the PTC2 procedure to create the program tree - /* cout << "queue size: " << queue.size() << endl; */ - /* cout << "entering first while loop...\n"; */ + cout << "queue size: " << queue.size() << endl; + cout << "entering first while loop...\n"; while ( queue.size() + s < max_size && queue.size() > 0) { // including the queue size in the max_size, since each element in queue @@ -317,14 +317,14 @@ tree SearchSpace::PTC2(Node root, int max_d, int max_size) const // always insert a non terminal (which by default has weights off). // this way, we can have PTC2 working properly. - /* cout << "queue size: " << queue.size() << endl; */ + cout << "queue size: " << queue.size() << endl; auto [qspot, t, d] = RandomDequeue(queue); - /* cout << "current depth: " << d << endl; */ + cout << "current depth: " << d << endl; if (d >= max_d || s >= max_size) { // choose terminal of matching type - /* cout << "getting " << DataTypeName[t] << " terminal\n"; */ + cout << "getting " << DataTypeName[t] << " terminal\n"; // qspot = sample_terminal(t); // Tree.replace(qspot, sample_terminal(t)); // Tree.append_child(qspot, sample_terminal(t)); @@ -344,15 +344,22 @@ tree SearchSpace::PTC2(Node root, int max_d, int max_size) const else { //choose a nonterminal of matching type - /* cout << "getting op of type " << DataTypeName[t] << endl; */ + cout << "getting op of type " << DataTypeName[t] << endl; auto opt = sample_op(t); - /* cout << "chose " << n.name << endl; */ + cout << "chose " << n.name << endl; // TreeIter new_spot = Tree.append_child(qspot, n); // qspot = n; - if (!opt) { - queue.push_back(make_tuple(qspot, t, d)); - continue; + if (!opt) { // there is no operator for this node. sample a terminal instead + opt = sample_terminal(t); + } + + if (!opt) { // no operator nor terminal. weird. + auto msg = fmt::format("Failed to sample operator AND terminal of data type {} during PTC2.\n", DataTypeName[t]); + HANDLE_ERROR_THROW(msg); + + // queue.push_back(make_tuple(qspot, t, d)); + // continue; } n = opt.value(); @@ -362,7 +369,7 @@ tree SearchSpace::PTC2(Node root, int max_d, int max_size) const // For each arg of n, add to queue for (auto a : n.arg_types) { - /* cout << "queing a node of type " << DataTypeName[a] << endl; */ + cout << "queing a node of type " << DataTypeName[a] << endl; // queue.push_back(make_tuple(new_spot, a, d+1)); auto child_spot = Tree.append_child(newspot); @@ -382,19 +389,20 @@ tree SearchSpace::PTC2(Node root, int max_d, int max_size) const && Isnt(n.node_type) ) s += 2; - /* cout << "current tree size: " << s << endl; */ + cout << "current tree size: " << s << endl; } - /* cout << "entering second while loop...\n"; */ + + cout << "entering second while loop...\n"; while (queue.size() > 0) { if (queue.size() == 0) break; - /* cout << "queue size: " << queue.size() << endl; */ + cout << "queue size: " << queue.size() << endl; auto [qspot, t, d] = RandomDequeue(queue); - /* cout << "getting " << DataTypeName[t] << " terminal\n"; */ + cout << "getting " << DataTypeName[t] << " terminal\n"; // Tree.append_child(qspot, sample_terminal(t)); // qspot = sample_terminal(t); // auto newspot = Tree.replace(qspot, sample_terminal(t)); @@ -408,11 +416,9 @@ tree SearchSpace::PTC2(Node root, int max_d, int max_size) const auto newspot = Tree.replace(qspot, n); } - /* cout << "final tree:\n" */ - /* << Tree.begin().node->get_model() << "\n" */ - /* << Tree.begin().node->get_tree_model(true) << endl; */ - /* << Tree.get_model() << "\n" */ - /* << Tree.get_model(true) << endl; // pretty */ + cout << "final tree:\n" + << Tree.begin().node->get_model() << "\n" + << Tree.begin().node->get_tree_model(true) << endl; return Tree; }; diff --git a/tests/cpp/test_program.cpp b/tests/cpp/test_program.cpp index 69d5819b..e838d60d 100644 --- a/tests/cpp/test_program.cpp +++ b/tests/cpp/test_program.cpp @@ -97,6 +97,8 @@ TEST(Program, PredictWithWeights) Dataset data = Data::read_csv("docs/examples/datasets/d_enc.csv","label"); + ASSERT_FALSE(data.classification); + SearchSpace SS; SS.init(data); @@ -138,16 +140,20 @@ TEST(Program, FitClassifier) { Parameters params; - Dataset data = Data::read_csv("docs/examples/datasets/d_analcatdata_aids.csv","target"); + Dataset data = Data::read_csv("docs/examples/datasets/d_analcatdata_aids.csv", "target"); + + ASSERT_TRUE(data.classification); SearchSpace SS; + SS.init(data); for (int d = 1; d < 10; ++d) { for (int s = 1; s < 100; s+=10) { - params.max_size = s; params.max_depth = d; + params.max_size = s; + fmt::print( "Calling make_classifier...\n"); auto PRG = SS.make_classifier(0, 0, params); fmt::print( @@ -156,8 +162,12 @@ TEST(Program, FitClassifier) "=================================================\n", d, s, PRG.get_model("compact", true) ); + + fmt::print( "Fitting the model...\n"); PRG.fit(data); + fmt::print( "predict...\n"); auto y = PRG.predict(data); + fmt::print( "predict proba...\n"); auto yproba = PRG.predict_proba(data); } }