diff --git a/vowpalwabbit/core/include/vw/core/example.h b/vowpalwabbit/core/include/vw/core/example.h index 99f7e8bb0cd..c13a8975f3f 100644 --- a/vowpalwabbit/core/include/vw/core/example.h +++ b/vowpalwabbit/core/include/vw/core/example.h @@ -143,26 +143,8 @@ class example : public example_predict // core example datatype. class workspace; -class flat_example -{ -public: - polylabel l; - reduction_features ex_reduction_features; - - VW::v_array tag; // An identifier for the example. - - size_t example_counter; - uint64_t ft_offset; - float global_weight; - - size_t num_features; // precomputed, cause it's fast&easy. - float total_sum_feat_sq; // precomputed, cause it's kind of fast & easy. - features fs; // all the features -}; - -flat_example* flatten_example(VW::workspace& all, example* ec); -flat_example* flatten_sort_example(VW::workspace& all, example* ec); -void free_flatten_example(flat_example* fec); +// TODO: make workspace and example const +void flatten_features(VW::workspace& all, example& ec, features& fs); inline bool example_is_newline(const example& ec) { return ec.is_newline; } @@ -194,13 +176,6 @@ void truncate_example_namespace(VW::example& ec, VW::namespace_index ns, const f void append_example_namespaces_from_example(VW::example& target, const VW::example& source); void truncate_example_namespaces_from_example(VW::example& target, const VW::example& source); } // namespace details - -namespace model_utils -{ -size_t read_model_field(io_buf& io, flat_example& fe, VW::label_parser& lbl_parser); -size_t write_model_field(io_buf& io, const flat_example& fe, const std::string& upstream_name, bool text, - VW::label_parser& lbl_parser, uint64_t parse_mask); -} // namespace model_utils } // namespace VW // Deprecated compat definitions @@ -209,18 +184,6 @@ using polylabel VW_DEPRECATED("polylabel moved into VW namespace") = VW::polylab using polyprediction VW_DEPRECATED("polyprediction moved into VW namespace") = VW::polyprediction; using example VW_DEPRECATED("example moved into VW namespace") = VW::example; using multi_ex VW_DEPRECATED("multi_ex moved into VW namespace") = VW::multi_ex; -using flat_example VW_DEPRECATED("flat_example moved into VW namespace") = VW::flat_example; - -VW_DEPRECATED("flatten_example moved into VW namespace") -inline VW::flat_example* flatten_example(VW::workspace& all, VW::example* ec) { return VW::flatten_example(all, ec); } - -VW_DEPRECATED("flatten_sort_example moved into VW namespace") -inline VW::flat_example* flatten_sort_example(VW::workspace& all, VW::example* ec) -{ - return VW::flatten_sort_example(all, ec); -} -VW_DEPRECATED("free_flatten_example moved into VW namespace") -inline void free_flatten_example(VW::flat_example* fec) { return VW::free_flatten_example(fec); } VW_DEPRECATED("example_is_newline moved into VW namespace") inline bool example_is_newline(const VW::example& ec) { return VW::example_is_newline(ec); } diff --git a/vowpalwabbit/core/include/vw/core/feature_group.h b/vowpalwabbit/core/include/vw/core/feature_group.h index f19e7ea4a80..569c54ca482 100644 --- a/vowpalwabbit/core/include/vw/core/feature_group.h +++ b/vowpalwabbit/core/include/vw/core/feature_group.h @@ -510,6 +510,10 @@ class features return all_extents_complete; } }; + +/// Both fs1 and fs2 must be sorted. +/// Most often used with VW::flatten_features +float features_dot_product(const features& fs1, const features& fs2); } // namespace VW using feature_value VW_DEPRECATED("Moved into VW namespace. Will be removed in VW 10.") = VW::feature_value; diff --git a/vowpalwabbit/core/src/example.cc b/vowpalwabbit/core/src/example.cc index dbd671c825a..2f726e55e58 100644 --- a/vowpalwabbit/core/src/example.cc +++ b/vowpalwabbit/core/src/example.cc @@ -48,6 +48,9 @@ float VW::example::get_total_sum_feat_sq() float collision_cleanup(VW::features& fs) { + // Input must be sorted. + assert(std::is_sorted(fs.indices.begin(), fs.indices.end())); + // This loops over the sequence of feature values and their indexes // when an index is repeated this combines them by adding their values. // This assumes that fs is sorted (which is the case in `flatten_sort_example`). @@ -105,46 +108,23 @@ void vec_ffs_store(full_features_and_source& p, float fx, uint64_t fi) } namespace VW { -flat_example* flatten_example(VW::workspace& all, example* ec) -{ - flat_example& fec = VW::details::calloc_or_throw(); - fec.l = ec->l; - fec.tag = ec->tag; - fec.ex_reduction_features = ec->ex_reduction_features; - fec.example_counter = ec->example_counter; - fec.ft_offset = ec->ft_offset; - fec.num_features = ec->num_features; +void flatten_features(VW::workspace& all, example& ec, features& fs) +{ + fs.clear(); full_features_and_source ffs; + ffs.fs = std::move(fs); ffs.stride_shift = all.weights.stride_shift(); if (all.weights.not_null()) - { // TODO:temporary fix. all.weights is not initialized at this point in some cases. + { + // TODO:temporary fix. all.weights is not initialized at this point in some cases. ffs.mask = all.weights.mask() >> all.weights.stride_shift(); } else { ffs.mask = static_cast(LONG_MAX) >> all.weights.stride_shift(); } - VW::foreach_feature(all, *ec, ffs); - - std::swap(fec.fs, ffs.fs); - - return &fec; -} - -flat_example* flatten_sort_example(VW::workspace& all, example* ec) -{ - flat_example* fec = flatten_example(all, ec); - fec->fs.sort(all.parse_mask); - fec->total_sum_feat_sq = collision_cleanup(fec->fs); - return fec; -} - -void free_flatten_example(flat_example* fec) -{ - // note: The label memory should be freed by by freeing the original example. - if (fec) - { - fec->fs.~features(); - free(fec); - } + VW::foreach_feature(all, ec, ffs); + ffs.fs.sort(all.parse_mask); + ffs.fs.sum_feat_sq = collision_cleanup(ffs.fs); + fs = std::move(ffs.fs); } void return_multiple_example(VW::workspace& all, VW::multi_ex& examples) @@ -213,42 +193,6 @@ void truncate_example_namespaces_from_example(VW::example& target, const VW::exa } } } // namespace details - -namespace model_utils -{ -size_t read_model_field(io_buf& io, flat_example& fe, VW::label_parser& lbl_parser) -{ - size_t bytes = 0; - lbl_parser.default_label(fe.l); - bytes += lbl_parser.read_cached_label(fe.l, fe.ex_reduction_features, io); - bytes += read_model_field(io, fe.tag); - bytes += read_model_field(io, fe.example_counter); - bytes += read_model_field(io, fe.ft_offset); - bytes += read_model_field(io, fe.global_weight); - bytes += read_model_field(io, fe.num_features); - bytes += read_model_field(io, fe.total_sum_feat_sq); - unsigned char index = 0; - bytes += ::VW::parsers::cache::details::read_cached_index(io, index); - bool sorted = true; - bytes += ::VW::parsers::cache::details::read_cached_features(io, fe.fs, sorted); - return bytes; -} -size_t write_model_field(io_buf& io, const flat_example& fe, const std::string& upstream_name, bool text, - VW::label_parser& lbl_parser, uint64_t parse_mask) -{ - size_t bytes = 0; - lbl_parser.cache_label(fe.l, fe.ex_reduction_features, io, upstream_name + "_label", text); - bytes += write_model_field(io, fe.tag, upstream_name + "_tag", text); - bytes += write_model_field(io, fe.example_counter, upstream_name + "_example_counter", text); - bytes += write_model_field(io, fe.ft_offset, upstream_name + "_ft_offset", text); - bytes += write_model_field(io, fe.global_weight, upstream_name + "_global_weight", text); - bytes += write_model_field(io, fe.num_features, upstream_name + "_num_features", text); - bytes += write_model_field(io, fe.total_sum_feat_sq, upstream_name + "_total_sum_feat_sq", text); - ::VW::parsers::cache::details::cache_index(io, 0); - ::VW::parsers::cache::details::cache_features(io, fe.fs, parse_mask); - return bytes; -} -} // namespace model_utils } // namespace VW namespace VW diff --git a/vowpalwabbit/core/src/feature_group.cc b/vowpalwabbit/core/src/feature_group.cc index 64f05a87390..ccb50e30484 100644 --- a/vowpalwabbit/core/src/feature_group.cc +++ b/vowpalwabbit/core/src/feature_group.cc @@ -297,3 +297,28 @@ void VW::features::end_ns_extent() } } } + +float VW::features_dot_product(const features& fs1, const features& fs2) +{ + assert(std::is_sorted(fs1.indices.begin(), fs1.indices.end())); + assert(std::is_sorted(fs2.indices.begin(), fs2.indices.end())); + + float dotprod = 0; + if (fs2.indices.empty()) { return 0.f; } + + for (size_t idx1 = 0, idx2 = 0; idx1 < fs1.size() && idx2 < fs2.size(); idx1++) + { + uint64_t ec1pos = fs1.indices[idx1]; + uint64_t ec2pos = fs2.indices[idx2]; + if (ec1pos < ec2pos) { continue; } + + while (ec1pos > ec2pos && ++idx2 < fs2.size()) { ec2pos = fs2.indices[idx2]; } + + if (ec1pos == ec2pos) + { + dotprod += fs1.values[idx1] * fs2.values[idx2]; + ++idx2; + } + } + return dotprod; +} \ No newline at end of file diff --git a/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc b/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc index a22b3fd0953..99f2cdabffa 100644 --- a/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc +++ b/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc @@ -11,6 +11,7 @@ #include "vw/config/options.h" #include "vw/core/array_parameters.h" #include "vw/core/example.h" +#include "vw/core/feature_group.h" #include "vw/core/learner.h" #include "vw/core/memory.h" #include "vw/core/model_utils.h" @@ -75,14 +76,14 @@ emt_example::emt_example(VW::workspace& all, VW::example* ex) std::vector> base_interactions; ex->interactions = &base_interactions; - auto* ex1 = VW::flatten_sort_example(all, ex); - for (auto& f : ex1->fs) { base.emplace_back(f.index(), f.value()); } - VW::free_flatten_example(ex1); + VW::features fs; + VW::flatten_features(all, *ex, fs); + for (auto& f : fs) { base.emplace_back(f.index(), f.value()); } + fs.clear(); ex->interactions = full_interactions; - auto* ex2 = VW::flatten_sort_example(all, ex); - for (auto& f : ex2->fs) { full.emplace_back(f.index(), f.value()); } - VW::free_flatten_example(ex2); + VW::flatten_features(all, *ex, fs); + for (auto& f : fs) { full.emplace_back(f.index(), f.value()); } } emt_lru::emt_lru(uint64_t max_size) : max_size(max_size) {} diff --git a/vowpalwabbit/core/src/reductions/kernel_svm.cc b/vowpalwabbit/core/src/reductions/kernel_svm.cc index 3112a8eddf0..9465e9f0478 100644 --- a/vowpalwabbit/core/src/reductions/kernel_svm.cc +++ b/vowpalwabbit/core/src/reductions/kernel_svm.cc @@ -9,6 +9,7 @@ #include "vw/core/accumulate.h" #include "vw/core/constant.h" #include "vw/core/example.h" +#include "vw/core/feature_group.h" #include "vw/core/learner.h" #include "vw/core/loss_functions.h" #include "vw/core/memory.h" @@ -43,6 +44,75 @@ using std::endl; namespace { + +class flat_example +{ +public: + VW::polylabel l; + VW::reduction_features ex_reduction_features; + + VW::v_array tag; // An identifier for the example. + + size_t example_counter; + uint64_t ft_offset; + float global_weight; + + size_t num_features; // precomputed, cause it's fast&easy. + float total_sum_feat_sq; // precomputed, cause it's kind of fast & easy. + VW::features fs; // all the features +}; + +flat_example* flatten_sort_example(VW::workspace& all, VW::example* ec) +{ + auto& fec = VW::details::calloc_or_throw(); + fec.l = ec->l; + fec.tag = ec->tag; + fec.ex_reduction_features = ec->ex_reduction_features; + fec.example_counter = ec->example_counter; + fec.ft_offset = ec->ft_offset; + fec.num_features = ec->num_features; + + flatten_features(all, *ec, fec.fs); + fec.total_sum_feat_sq = fec.fs.sum_feat_sq; + + return &fec; +} + +// TODO: do not depend on unstable cache format for model format of KSVM. +size_t read_model_field_flat_example(VW::io_buf& io, flat_example& fe, VW::label_parser& lbl_parser) +{ + size_t bytes = 0; + lbl_parser.default_label(fe.l); + bytes += lbl_parser.read_cached_label(fe.l, fe.ex_reduction_features, io); + bytes += VW::model_utils::read_model_field(io, fe.tag); + bytes += VW::model_utils::read_model_field(io, fe.example_counter); + bytes += VW::model_utils::read_model_field(io, fe.ft_offset); + bytes += VW::model_utils::read_model_field(io, fe.global_weight); + bytes += VW::model_utils::read_model_field(io, fe.num_features); + bytes += VW::model_utils::read_model_field(io, fe.total_sum_feat_sq); + unsigned char index = 0; + bytes += ::VW::parsers::cache::details::read_cached_index(io, index); + bool sorted = true; + bytes += ::VW::parsers::cache::details::read_cached_features(io, fe.fs, sorted); + return bytes; +} + +size_t write_model_field_flat_example(VW::io_buf& io, const flat_example& fe, const std::string& upstream_name, + bool text, VW::label_parser& lbl_parser, uint64_t parse_mask) +{ + size_t bytes = 0; + lbl_parser.cache_label(fe.l, fe.ex_reduction_features, io, upstream_name + "_label", text); + bytes += VW::model_utils::write_model_field(io, fe.tag, upstream_name + "_tag", text); + bytes += VW::model_utils::write_model_field(io, fe.example_counter, upstream_name + "_example_counter", text); + bytes += VW::model_utils::write_model_field(io, fe.ft_offset, upstream_name + "_ft_offset", text); + bytes += VW::model_utils::write_model_field(io, fe.global_weight, upstream_name + "_global_weight", text); + bytes += VW::model_utils::write_model_field(io, fe.num_features, upstream_name + "_num_features", text); + bytes += VW::model_utils::write_model_field(io, fe.total_sum_feat_sq, upstream_name + "_total_sum_feat_sq", text); + ::VW::parsers::cache::details::cache_index(io, 0); + ::VW::parsers::cache::details::cache_features(io, fe.fs, parse_mask); + return bytes; +} + class svm_params; static size_t num_kernel_evals = 0; @@ -52,10 +122,10 @@ class svm_example { public: VW::v_array krow; - VW::flat_example ex; + flat_example ex; ~svm_example(); - void init_svm_example(VW::flat_example* fec); + void init_svm_example(flat_example* fec); int compute_kernels(svm_params& params); int clear_kernels(); }; @@ -132,7 +202,7 @@ class svm_params } }; -void svm_example::init_svm_example(VW::flat_example* fec) +void svm_example::init_svm_example(flat_example* fec) { ex = std::move(*fec); free(fec); @@ -146,7 +216,7 @@ svm_example::~svm_example() // free_flatten_example(fec); // free contents of flat example and frees fec. } -float kernel_function(const VW::flat_example* fec1, const VW::flat_example* fec2, void* params, size_t kernel_type); +float kernel_function(const flat_example* fec1, const flat_example* fec2, void* params, size_t kernel_type); int svm_example::compute_kernels(svm_params& params) { @@ -252,15 +322,15 @@ void save_load_svm_model(svm_params& params, VW::io_buf& model_file, bool read, { if (read) { - auto fec = VW::make_unique(); + auto fec = VW::make_unique(); auto* tmp = &VW::details::calloc_or_throw(); - VW::model_utils::read_model_field(model_file, *fec, params.all->example_parser->lbl_parser); + read_model_field_flat_example(model_file, *fec, params.all->example_parser->lbl_parser); tmp->ex = *fec; model->support_vec.push_back(tmp); } else { - VW::model_utils::write_model_field(model_file, model->support_vec[i]->ex, "_flat_example", false, + write_model_field_flat_example(model_file, model->support_vec[i]->ex, "_flat_example", false, params.all->example_parser->lbl_parser, params.all->parse_mask); } } @@ -289,45 +359,19 @@ void save_load(svm_params& params, VW::io_buf& model_file, bool read, bool text) save_load_svm_model(params, model_file, read, text); } -float linear_kernel(const VW::flat_example* fec1, const VW::flat_example* fec2) -{ - float dotprod = 0; - - auto& fs_1 = const_cast(fec1->fs); - auto& fs_2 = const_cast(fec2->fs); - if (fs_2.indices.size() == 0) { return 0.f; } - - for (size_t idx1 = 0, idx2 = 0; idx1 < fs_1.size() && idx2 < fs_2.size(); idx1++) - { - uint64_t ec1pos = fs_1.indices[idx1]; - uint64_t ec2pos = fs_2.indices[idx2]; - // params.all->opts_n_args.trace_message<x<<" "<x<< endl; - if (ec1pos < ec2pos) { continue; } - - while (ec1pos > ec2pos && ++idx2 < fs_2.size()) { ec2pos = fs_2.indices[idx2]; } - - if (ec1pos == ec2pos) - { - dotprod += fs_1.values[idx1] * fs_2.values[idx2]; - ++idx2; - } - } - return dotprod; -} - -float poly_kernel(const VW::flat_example* fec1, const VW::flat_example* fec2, int power) +float poly_kernel(const flat_example* fec1, const flat_example* fec2, int power) { - float dotprod = linear_kernel(fec1, fec2); + float dotprod = VW::features_dot_product(fec1->fs, fec2->fs); return static_cast(std::pow(1 + dotprod, power)); } -float rbf_kernel(const VW::flat_example* fec1, const VW::flat_example* fec2, float bandwidth) +float rbf_kernel(const flat_example* fec1, const flat_example* fec2, float bandwidth) { - float dotprod = linear_kernel(fec1, fec2); + float dotprod = VW::features_dot_product(fec1->fs, fec2->fs); return expf(-(fec1->total_sum_feat_sq + fec2->total_sum_feat_sq - 2 * dotprod) * bandwidth); } -float kernel_function(const VW::flat_example* fec1, const VW::flat_example* fec2, void* params, size_t kernel_type) +float kernel_function(const flat_example* fec1, const flat_example* fec2, void* params, size_t kernel_type) { switch (kernel_type) { @@ -336,7 +380,7 @@ float kernel_function(const VW::flat_example* fec1, const VW::flat_example* fec2 case SVM_KER_POLY: return poly_kernel(fec1, fec2, *(static_cast(params))); case SVM_KER_LIN: - return linear_kernel(fec1, fec2); + return VW::features_dot_product(fec1->fs, fec2->fs); } return 0; } @@ -365,7 +409,7 @@ void predict(svm_params& params, svm_example** ec_arr, float* scores, size_t n) void predict(svm_params& params, VW::example& ec) { - VW::flat_example* fec = VW::flatten_sort_example(*(params.all), &ec); + flat_example* fec = flatten_sort_example(*(params.all), &ec); if (fec) { svm_example* sec = &VW::details::calloc_or_throw(); @@ -502,15 +546,14 @@ void sync_queries(VW::workspace& all, svm_params& params, bool* train_pool) VW::io_buf* b = new VW::io_buf(); char* queries; - VW::flat_example* fec = nullptr; + flat_example* fec = nullptr; for (size_t i = 0; i < params.pool_pos; i++) { if (!train_pool[i]) { continue; } fec = &(params.pool[i]->ex); - VW::model_utils::write_model_field( - *b, *fec, "_flat_example", false, all.example_parser->lbl_parser, all.parse_mask); + write_model_field_flat_example(*b, *fec, "_flat_example", false, all.example_parser->lbl_parser, all.parse_mask); delete params.pool[i]; } @@ -539,7 +582,7 @@ void sync_queries(VW::workspace& all, svm_params& params, bool* train_pool) for (size_t i = 0; i < params.pool_size; i++) { - if (!VW::model_utils::read_model_field(*b, *fec, all.example_parser->lbl_parser)) + if (!read_model_field_flat_example(*b, *fec, all.example_parser->lbl_parser)) { params.pool[i] = &VW::details::calloc_or_throw(); params.pool[i]->init_svm_example(fec); @@ -672,7 +715,7 @@ void train(svm_params& params) void learn(svm_params& params, VW::example& ec) { - VW::flat_example* fec = VW::flatten_sort_example(*(params.all), &ec); + flat_example* fec = flatten_sort_example(*(params.all), &ec); if (fec) { svm_example* sec = &VW::details::calloc_or_throw(); diff --git a/vowpalwabbit/core/src/reductions/memory_tree.cc b/vowpalwabbit/core/src/reductions/memory_tree.cc index 7f5441a8e14..86a5b08fd0f 100644 --- a/vowpalwabbit/core/src/reductions/memory_tree.cc +++ b/vowpalwabbit/core/src/reductions/memory_tree.cc @@ -8,6 +8,7 @@ #include "vw/common/random.h" #include "vw/config/options.h" #include "vw/core/example.h" +#include "vw/core/feature_group.h" #include "vw/core/learner.h" #include "vw/core/multiclass.h" #include "vw/core/multilabel.h" @@ -226,39 +227,14 @@ class memory_tree } }; -float linear_kernel(const VW::flat_example* fec1, const VW::flat_example* fec2) -{ - float dotprod = 0; - - auto& fs_1 = const_cast(fec1->fs); - auto& fs_2 = const_cast(fec2->fs); - if (fs_2.indices.size() == 0) { return 0.f; } - - for (size_t idx1 = 0, idx2 = 0; idx1 < fs_1.size() && idx2 < fs_2.size(); idx1++) - { - uint64_t ec1pos = fs_1.indices[idx1]; - uint64_t ec2pos = fs_2.indices[idx2]; - if (ec1pos < ec2pos) { continue; } - - while (ec1pos > ec2pos && ++idx2 < fs_2.size()) { ec2pos = fs_2.indices[idx2]; } - - if (ec1pos == ec2pos) - { - dotprod += fs_1.values[idx1] * fs_2.values[idx2]; - ++idx2; - } - } - return dotprod; -} - float normalized_linear_prod(memory_tree& b, VW::example* ec1, VW::example* ec2) { - VW::flat_example* fec1 = VW::flatten_sort_example(*b.all, ec1); - VW::flat_example* fec2 = VW::flatten_sort_example(*b.all, ec2); - float norm_sqrt = std::pow(fec1->total_sum_feat_sq * fec2->total_sum_feat_sq, 0.5f); - float linear_prod = linear_kernel(fec1, fec2); - VW::free_flatten_example(fec1); - VW::free_flatten_example(fec2); + VW::features fs1; + VW::features fs2; + flatten_features(*b.all, *ec1, fs1); + flatten_features(*b.all, *ec2, fs2); + float norm_sqrt = std::pow(fs1.sum_feat_sq * fs2.sum_feat_sq, 0.5f); + float linear_prod = VW::features_dot_product(fs1, fs2); return linear_prod / norm_sqrt; } diff --git a/vowpalwabbit/core/tests/flat_example_test.cc b/vowpalwabbit/core/tests/flat_example_test.cc index 80641bb1dee..f5ecb35cd44 100644 --- a/vowpalwabbit/core/tests/flat_example_test.cc +++ b/vowpalwabbit/core/tests/flat_example_test.cc @@ -16,12 +16,12 @@ TEST(FlatExample, SansInteractionTest) auto vw = VW::initialize(vwtest::make_args("--quiet", "--noconstant")); auto* ex = VW::read_example(*vw, "1 |x a:2 |y b:3"); - auto& flat = *VW::flatten_sort_example(*vw, ex); + VW::features fs; + VW::flatten_features(*vw, *ex, fs); - EXPECT_THAT(flat.fs.values, testing::UnorderedElementsAre(2, 3)); - EXPECT_EQ(flat.total_sum_feat_sq, 13); + EXPECT_THAT(fs.values, testing::UnorderedElementsAre(2, 3)); + EXPECT_EQ(fs.sum_feat_sq, 13); - VW::free_flatten_example(&flat); VW::finish_example(*vw, *ex); } @@ -30,12 +30,12 @@ TEST(FlatExample, WithInteractionTest) auto vw = VW::initialize(vwtest::make_args("--interactions", "xy", "--quiet", "--noconstant")); auto* ex = VW::read_example(*vw, "1 |x a:2 |y b:3"); - auto& flat = *VW::flatten_sort_example(*vw, ex); + VW::features fs; + VW::flatten_features(*vw, *ex, fs); - EXPECT_THAT(flat.fs.values, testing::UnorderedElementsAre(2, 3, 6)); - EXPECT_EQ(flat.total_sum_feat_sq, 49); + EXPECT_THAT(fs.values, testing::UnorderedElementsAre(2, 3, 6)); + EXPECT_EQ(fs.sum_feat_sq, 49); - VW::free_flatten_example(&flat); VW::finish_example(*vw, *ex); } @@ -44,10 +44,10 @@ TEST(FlatExample, EmptyExampleTest) auto vw = VW::initialize(vwtest::make_args("--quiet", "--noconstant")); auto* ex = VW::read_example(*vw, "1 |x a:0"); - auto& flat = *VW::flatten_sort_example(*vw, ex); + VW::features fs; + VW::flatten_features(*vw, *ex, fs); - EXPECT_TRUE(flat.fs.empty()); + EXPECT_TRUE(fs.empty()); - VW::free_flatten_example(&flat); VW::finish_example(*vw, *ex); }