From 4f077a741a90a8868519bdf7660b3ccfc2f20a76 Mon Sep 17 00:00:00 2001 From: gAldeia Date: Tue, 16 Apr 2024 09:56:33 -0300 Subject: [PATCH] Logistic(Add(Const, <>)) --- src/search_space.cpp | 30 ++++++++++++++++--------- src/search_space.h | 53 ++++++++++++++++++++++++++++++++------------ src/variation.h | 2 +- 3 files changed, 60 insertions(+), 25 deletions(-) diff --git a/src/search_space.cpp b/src/search_space.cpp index 6aaf45cb..a6b11e8b 100644 --- a/src/search_space.cpp +++ b/src/search_space.cpp @@ -201,6 +201,10 @@ void SearchSpace::init(const Dataset& d, const unordered_map& user std::set unique_classes(vec.begin(), vec.end()); + // We need some ops in the search space so we can have the logit and offset + if (user_ops.find("OffsetSum") == user_ops.end()) + extended_user_ops.insert({"OffsetSum", 0.0f}); + if (unique_classes.size()==2 && (user_ops.find("Logistic") == user_ops.end())) { extended_user_ops.insert({"Logistic", 0.0f}); } @@ -249,13 +253,19 @@ std::optional> SearchSpace::sample_subtree(Node root, int max_d, int terminal_weights.at(root.ret_type).end())) ) return std::nullopt; + auto Tree = tree(); + auto spot = Tree.insert(Tree.begin(), root); + // we should notice the difference between size of a PROGRAM and a TREE. // program count weights in its size, while the TREE structure dont. Wenever // using size of a program/tree, make sure you use the function from the correct class - return PTC2(root, max_d, max_size); + PTC2(Tree, spot, max_d, max_size); + + return Tree; }; -tree SearchSpace::PTC2(Node root, int max_d, int max_size) const +tree& SearchSpace::PTC2(tree& Tree, + tree::iterator spot, int max_d, int max_size) const { // PTC2 is agnostic of program type @@ -265,23 +275,23 @@ tree SearchSpace::PTC2(Node root, int max_d, int max_size) const // parameters, the real maximum size that can occur is `max_size` plus the // highest operator arity, and the real maximum depth is `max_depth` plus one. - auto Tree = tree(); + // auto Tree = tree(); fmt::print("building program with max size {}, max depth {}",max_size,max_d); // Queue of nodes that need children vector> queue; - cout << "root " << root.name << endl; - // auto spot = Tree.set_head(n); - cout << "inserting...\n"; - auto spot = Tree.insert(Tree.begin(), root); - // node depth int d = 1; // current tree size int s = 1; + Node root = spot.node->data; + + cout << "root " << root.name << endl; + // auto spot = Tree.set_head(n); + // updating size accordingly to root node if (Is(root.node_type)) s += 3; @@ -289,7 +299,7 @@ tree SearchSpace::PTC2(Node root, int max_d, int max_size) const s += 2; if ( root.get_is_weighted()==true - && Isnt(root.node_type) ) + && Isnt(root.node_type) ) s += 2; //For each argument position a of n, Enqueue(a; g) @@ -386,7 +396,7 @@ tree SearchSpace::PTC2(Node root, int max_d, int max_size) const s += 2; if ( n.get_is_weighted()==true - && Isnt(n.node_type) ) + && Isnt(n.node_type) ) s += 2; cout << "current tree size: " << s << endl; diff --git a/src/search_space.h b/src/search_space.h index 83883360..ef0ea309 100644 --- a/src/search_space.h +++ b/src/search_space.h @@ -579,7 +579,7 @@ struct SearchSpace void print() const; private: - tree PTC2(Node root, int max_d, int max_size) const; + tree& PTC2(tree& Tree, tree::iterator root, int max_d, int max_size) const; template requires (!is_in_v) @@ -691,24 +691,46 @@ P SearchSpace::make_program(const Parameters& params, int max_d, int max_size) ProgramType program_type = P::program_type; // ProgramType program_type = ProgramTypeEnum::value; - // building the root node for each program case. We give the root, and it - // fills the rest of the tree - Node root; + // Tree is pre-filled with some fixed nodes depending on program type + auto Tree = tree(); + + // building the tree for each program case. Then, we give the spot to PTC2, + // and it will fill the rest of the tree + tree::iterator spot; // building the root node for each program case if (P::program_type == ProgramType::BinaryClassifier) { - root = get(NodeType::Logistic, DataType::ArrayF, Signature()); - root.set_prob_change(0.0); - root.fixed=true; + Node node_logit = get(NodeType::Logistic, DataType::ArrayF, Signature()); + node_logit.set_prob_change(0.0); + node_logit.fixed=true; + auto spot_logit = Tree.insert(Tree.begin(), node_logit); + + if (true) { // Logistic(Add(Constant, <>)). + Node node_offset = get(NodeType::OffsetSum, DataType::ArrayF, Signature()); + node_offset.set_prob_change(0.0); + node_offset.fixed=true; + + auto spot_offset = Tree.append_child(spot_logit); + + spot = Tree.replace(spot_offset, node_offset); + } + else { // If false, then model will be Logistic(<>) + spot = spot_logit; + } } else if (P::program_type == ProgramType::MulticlassClassifier) { - root = get(NodeType::Softmax, DataType::MatrixF, Signature()); - root.set_prob_change(0.0); - root.fixed=true; + Node node_softmax = get(NodeType::Softmax, DataType::MatrixF, Signature()); + node_softmax.set_prob_change(0.0); + node_softmax.fixed=true; + + spot = Tree.insert(Tree.begin(), node_softmax); } - else { + else // regression or representer --- sampling any candidate op or terminal + { + Node root; + std::optional opt=std::nullopt; if (max_size>1 && max_d>1) @@ -716,13 +738,16 @@ P SearchSpace::make_program(const Parameters& params, int max_d, int max_size) if (!opt) // if failed, then we dont have any operator to use as root... opt = sample_terminal(root_type, true); + root = opt.value(); - } + spot = Tree.insert(Tree.begin(), root); + } + // max_d-1 because we always pick the root before calling ptc2 - auto Tree = PTC2(root, max_d-1, max_size); + PTC2(Tree, spot, max_d-1, max_size); // change inplace - return P(*this,Tree); + return P(*this, Tree); }; extern SearchSpace SS; diff --git a/src/variation.h b/src/variation.h index bbf3af63..0bca47dc 100644 --- a/src/variation.h +++ b/src/variation.h @@ -83,7 +83,7 @@ class MutationBase { acc += 2; if ( (include_weight && node.get_is_weighted()==true) - && Isnt(node.node_type) ) + && Isnt(node.node_type) ) // Taking into account the weight and multiplication, if enabled. // weighted constants still count as 1 (simpler than constant terminals) acc += 2;