Skip to content

Commit

Permalink
Improving PTC2 to work with strongly typed programs. Fixed infinite loop in PTC2
Browse files Browse the repository at this point in the history
  • Loading branch information
gAldeia committed Apr 14, 2024
1 parent 39a1e16 commit 2c438e0
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 29 deletions.
7 changes: 4 additions & 3 deletions src/data/io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,10 @@ Dataset read_csv (
// check if endpoint is binary
bool binary_endpoint = (y.array() == 0 || y.array() == 1).all();

auto result = Dataset(features,y,binary_endpoint);
return result;

// using constructor 1. (initializing data from a map)
auto result = Dataset(features, y, binary_endpoint);

return result;
}

} // Brush
Expand Down
54 changes: 30 additions & 24 deletions src/search_space.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,14 +267,14 @@ tree<Node> SearchSpace::PTC2(Node root, int max_d, int max_size) const

auto Tree = tree<Node>();

/* fmt::print("building program with max size {}, max depth {}",max_size,max_d); */
fmt::print("building program with max size {}, max depth {}",max_size,max_d);

// Queue of nodes that need children
vector<tuple<TreeIter, DataType, int>> queue;

/* cout << "chose " << n.name << endl; */
cout << "root " << root.name << endl;
// auto spot = Tree.set_head(n);
/* cout << "inserting...\n"; */
cout << "inserting...\n";
auto spot = Tree.insert(Tree.begin(), root);

// node depth
Expand All @@ -295,7 +295,7 @@ tree<Node> SearchSpace::PTC2(Node root, int max_d, int max_size) const
//For each argument position a of n, Enqueue(a; g)
for (auto a : root.arg_types)
{
/* cout << "queing a node of type " << DataTypeName[a] << endl; */
cout << "queing a node of type " << DataTypeName[a] << endl;
auto child_spot = Tree.append_child(spot);
queue.push_back(make_tuple(child_spot, a, d));
}
Expand All @@ -304,8 +304,8 @@ tree<Node> SearchSpace::PTC2(Node root, int max_d, int max_size) const

Node n;
// Now we actually start the PTC2 procedure to create the program tree
/* cout << "queue size: " << queue.size() << endl; */
/* cout << "entering first while loop...\n"; */
cout << "queue size: " << queue.size() << endl;
cout << "entering first while loop...\n";
while ( queue.size() + s < max_size && queue.size() > 0)
{
// including the queue size in the max_size, since each element in queue
Expand All @@ -317,14 +317,14 @@ tree<Node> SearchSpace::PTC2(Node root, int max_d, int max_size) const
// always insert a non terminal (which by default has weights off).
// this way, we can have PTC2 working properly.

/* cout << "queue size: " << queue.size() << endl; */
cout << "queue size: " << queue.size() << endl;
auto [qspot, t, d] = RandomDequeue(queue);

/* cout << "current depth: " << d << endl; */
cout << "current depth: " << d << endl;
if (d >= max_d || s >= max_size)
{
// choose terminal of matching type
/* cout << "getting " << DataTypeName[t] << " terminal\n"; */
cout << "getting " << DataTypeName[t] << " terminal\n";
// qspot = sample_terminal(t);
// Tree.replace(qspot, sample_terminal(t));
// Tree.append_child(qspot, sample_terminal(t));
Expand All @@ -344,15 +344,22 @@ tree<Node> SearchSpace::PTC2(Node root, int max_d, int max_size) const
else
{
//choose a nonterminal of matching type
/* cout << "getting op of type " << DataTypeName[t] << endl; */
cout << "getting op of type " << DataTypeName[t] << endl;
auto opt = sample_op(t);
/* cout << "chose " << n.name << endl; */
cout << "chose " << n.name << endl;
// TreeIter new_spot = Tree.append_child(qspot, n);
// qspot = n;

if (!opt) {
queue.push_back(make_tuple(qspot, t, d));
continue;
if (!opt) { // there is no operator for this node. sample a terminal instead
opt = sample_terminal(t);
}

if (!opt) { // no operator nor terminal. weird.
auto msg = fmt::format("Failed to sample operator AND terminal of data type {} during PTC2.\n", DataTypeName[t]);
HANDLE_ERROR_THROW(msg);

// queue.push_back(make_tuple(qspot, t, d));
// continue;
}

n = opt.value();
Expand All @@ -362,7 +369,7 @@ tree<Node> SearchSpace::PTC2(Node root, int max_d, int max_size) const
// For each arg of n, add to queue
for (auto a : n.arg_types)
{
/* cout << "queing a node of type " << DataTypeName[a] << endl; */
cout << "queing a node of type " << DataTypeName[a] << endl;
// queue.push_back(make_tuple(new_spot, a, d+1));
auto child_spot = Tree.append_child(newspot);

Expand All @@ -382,19 +389,20 @@ tree<Node> SearchSpace::PTC2(Node root, int max_d, int max_size) const
&& Isnt<NodeType::Constant, NodeType::MeanLabel, NodeType::OffsetSum>(n.node_type) )
s += 2;

/* cout << "current tree size: " << s << endl; */
cout << "current tree size: " << s << endl;
}
/* cout << "entering second while loop...\n"; */

cout << "entering second while loop...\n";
while (queue.size() > 0)
{
if (queue.size() == 0)
break;

/* cout << "queue size: " << queue.size() << endl; */
cout << "queue size: " << queue.size() << endl;

auto [qspot, t, d] = RandomDequeue(queue);

/* cout << "getting " << DataTypeName[t] << " terminal\n"; */
cout << "getting " << DataTypeName[t] << " terminal\n";
// Tree.append_child(qspot, sample_terminal(t));
// qspot = sample_terminal(t);
// auto newspot = Tree.replace(qspot, sample_terminal(t));
Expand All @@ -408,11 +416,9 @@ tree<Node> SearchSpace::PTC2(Node root, int max_d, int max_size) const
auto newspot = Tree.replace(qspot, n);
}

/* cout << "final tree:\n" */
/* << Tree.begin().node->get_model() << "\n" */
/* << Tree.begin().node->get_tree_model(true) << endl; */
/* << Tree.get_model() << "\n" */
/* << Tree.get_model(true) << endl; // pretty */
cout << "final tree:\n"
<< Tree.begin().node->get_model() << "\n"
<< Tree.begin().node->get_tree_model(true) << endl;

return Tree;
};
Expand Down
14 changes: 12 additions & 2 deletions tests/cpp/test_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ TEST(Program, PredictWithWeights)

Dataset data = Data::read_csv("docs/examples/datasets/d_enc.csv","label");

ASSERT_FALSE(data.classification);

SearchSpace SS;
SS.init(data);

Expand Down Expand Up @@ -138,16 +140,20 @@ TEST(Program, FitClassifier)
{
Parameters params;

Dataset data = Data::read_csv("docs/examples/datasets/d_analcatdata_aids.csv","target");
Dataset data = Data::read_csv("docs/examples/datasets/d_analcatdata_aids.csv", "target");

ASSERT_TRUE(data.classification);
SearchSpace SS;

SS.init(data);

for (int d = 1; d < 10; ++d) {
for (int s = 1; s < 100; s+=10) {

params.max_size = s;
params.max_depth = d;
params.max_size = s;

fmt::print( "Calling make_classifier...\n");
auto PRG = SS.make_classifier(0, 0, params);

fmt::print(
Expand All @@ -156,8 +162,12 @@ TEST(Program, FitClassifier)
"=================================================\n",
d, s, PRG.get_model("compact", true)
);

fmt::print( "Fitting the model...\n");
PRG.fit(data);
fmt::print( "predict...\n");
auto y = PRG.predict(data);
fmt::print( "predict proba...\n");
auto yproba = PRG.predict_proba(data);
}
}
Expand Down

0 comments on commit 2c438e0

Please sign in to comment.