Skip to content

Commit

Permalink
Logistic(Add(Const, <>))
Browse files Browse the repository at this point in the history
  • Loading branch information
gAldeia committed Apr 16, 2024
1 parent c2d3648 commit 4f077a7
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 25 deletions.
30 changes: 20 additions & 10 deletions src/search_space.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,10 @@ void SearchSpace::init(const Dataset& d, const unordered_map<string,float>& user

std::set<float> unique_classes(vec.begin(), vec.end());

// We need some ops in the search space so we can have the logit and offset
if (user_ops.find("OffsetSum") == user_ops.end())
extended_user_ops.insert({"OffsetSum", 0.0f});

if (unique_classes.size()==2 && (user_ops.find("Logistic") == user_ops.end())) {
extended_user_ops.insert({"Logistic", 0.0f});
}
Expand Down Expand Up @@ -249,13 +253,19 @@ std::optional<tree<Node>> SearchSpace::sample_subtree(Node root, int max_d, int
terminal_weights.at(root.ret_type).end())) )
return std::nullopt;

auto Tree = tree<Node>();
auto spot = Tree.insert(Tree.begin(), root);

// we should notice the difference between size of a PROGRAM and a TREE.
// program count weights in its size, while the TREE structure dont. Wenever
// using size of a program/tree, make sure you use the function from the correct class
return PTC2(root, max_d, max_size);
PTC2(Tree, spot, max_d, max_size);

return Tree;
};

tree<Node> SearchSpace::PTC2(Node root, int max_d, int max_size) const
tree<Node>& SearchSpace::PTC2(tree<Node>& Tree,
tree<Node>::iterator spot, int max_d, int max_size) const
{
// PTC2 is agnostic of program type

Expand All @@ -265,31 +275,31 @@ tree<Node> SearchSpace::PTC2(Node root, int max_d, int max_size) const
// parameters, the real maximum size that can occur is `max_size` plus the
// highest operator arity, and the real maximum depth is `max_depth` plus one.

auto Tree = tree<Node>();
// auto Tree = tree<Node>();

fmt::print("building program with max size {}, max depth {}",max_size,max_d);

// Queue of nodes that need children
vector<tuple<TreeIter, DataType, int>> queue;

cout << "root " << root.name << endl;
// auto spot = Tree.set_head(n);
cout << "inserting...\n";
auto spot = Tree.insert(Tree.begin(), root);

// node depth
int d = 1;
// current tree size
int s = 1;

Node root = spot.node->data;

cout << "root " << root.name << endl;
// auto spot = Tree.set_head(n);

// updating size accordingly to root node
if (Is<NodeType::SplitBest>(root.node_type))
s += 3;
else if (Is<NodeType::SplitOn>(root.node_type))
s += 2;

if ( root.get_is_weighted()==true
&& Isnt<NodeType::Constant, NodeType::MeanLabel, NodeType::OffsetSum>(root.node_type) )
&& Isnt<NodeType::Constant, NodeType::MeanLabel>(root.node_type) )
s += 2;

//For each argument position a of n, Enqueue(a; g)
Expand Down Expand Up @@ -386,7 +396,7 @@ tree<Node> SearchSpace::PTC2(Node root, int max_d, int max_size) const
s += 2;

if ( n.get_is_weighted()==true
&& Isnt<NodeType::Constant, NodeType::MeanLabel, NodeType::OffsetSum>(n.node_type) )
&& Isnt<NodeType::Constant, NodeType::MeanLabel>(n.node_type) )
s += 2;

cout << "current tree size: " << s << endl;
Expand Down
53 changes: 39 additions & 14 deletions src/search_space.h
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ struct SearchSpace
void print() const;

private:
tree<Node> PTC2(Node root, int max_d, int max_size) const;
tree<Node>& PTC2(tree<Node>& Tree, tree<Node>::iterator root, int max_d, int max_size) const;

template<NodeType NT, typename S>
requires (!is_in_v<NT, NodeType::Terminal, NodeType::Constant, NodeType::MeanLabel>)
Expand Down Expand Up @@ -691,38 +691,63 @@ P SearchSpace::make_program(const Parameters& params, int max_d, int max_size)
ProgramType program_type = P::program_type;
// ProgramType program_type = ProgramTypeEnum<PT>::value;

// building the root node for each program case. We give the root, and it
// fills the rest of the tree
Node root;
// Tree is pre-filled with some fixed nodes depending on program type
auto Tree = tree<Node>();

// building the tree for each program case. Then, we give the spot to PTC2,
// and it will fill the rest of the tree
tree<Node>::iterator spot;

// building the root node for each program case
if (P::program_type == ProgramType::BinaryClassifier)
{
root = get(NodeType::Logistic, DataType::ArrayF, Signature<ArrayXf(ArrayXf)>());
root.set_prob_change(0.0);
root.fixed=true;
Node node_logit = get(NodeType::Logistic, DataType::ArrayF, Signature<ArrayXf(ArrayXf)>());
node_logit.set_prob_change(0.0);
node_logit.fixed=true;
auto spot_logit = Tree.insert(Tree.begin(), node_logit);

if (true) { // Logistic(Add(Constant, <>)).
Node node_offset = get(NodeType::OffsetSum, DataType::ArrayF, Signature<ArrayXf(ArrayXf)>());
node_offset.set_prob_change(0.0);
node_offset.fixed=true;

auto spot_offset = Tree.append_child(spot_logit);

spot = Tree.replace(spot_offset, node_offset);
}
else { // If false, then model will be Logistic(<>)
spot = spot_logit;
}
}
else if (P::program_type == ProgramType::MulticlassClassifier)
{
root = get(NodeType::Softmax, DataType::MatrixF, Signature<ArrayXXf(ArrayXXf)>());
root.set_prob_change(0.0);
root.fixed=true;
Node node_softmax = get(NodeType::Softmax, DataType::MatrixF, Signature<ArrayXXf(ArrayXXf)>());
node_softmax.set_prob_change(0.0);
node_softmax.fixed=true;

spot = Tree.insert(Tree.begin(), node_softmax);
}
else {
else // regression or representer --- sampling any candidate op or terminal
{
Node root;

std::optional<Node> opt=std::nullopt;

if (max_size>1 && max_d>1)
opt = sample_op(root_type);

if (!opt) // if failed, then we dont have any operator to use as root...
opt = sample_terminal(root_type, true);

root = opt.value();
}

spot = Tree.insert(Tree.begin(), root);
}

// max_d-1 because we always pick the root before calling ptc2
auto Tree = PTC2(root, max_d-1, max_size);
PTC2(Tree, spot, max_d-1, max_size); // change inplace

return P(*this,Tree);
return P(*this, Tree);
};

extern SearchSpace SS;
Expand Down
2 changes: 1 addition & 1 deletion src/variation.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class MutationBase {
acc += 2;

if ( (include_weight && node.get_is_weighted()==true)
&& Isnt<NodeType::Constant, NodeType::MeanLabel, NodeType::OffsetSum>(node.node_type) )
&& Isnt<NodeType::Constant, NodeType::MeanLabel>(node.node_type) )
// Taking into account the weight and multiplication, if enabled.
// weighted constants still count as 1 (simpler than constant terminals)
acc += 2;
Expand Down

0 comments on commit 4f077a7

Please sign in to comment.