Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: make flat_example an implementation detail of ksvm #4505

Merged
merged 8 commits into from
Mar 3, 2023
41 changes: 2 additions & 39 deletions vowpalwabbit/core/include/vw/core/example.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,26 +143,8 @@ class example : public example_predict // core example datatype.

class workspace;

// A flattened copy of an example: label, tag and bookkeeping fields sit next to
// a single `features` object that holds every feature of the example.
// Instances are produced by flatten_example()/flatten_sort_example() and must
// be released with free_flatten_example().
class flat_example
{
public:
polylabel l; // Label, copied from the source example by flatten_example().
reduction_features ex_reduction_features; // Copied from the source example by flatten_example().

VW::v_array<char> tag; // An identifier for the example.

size_t example_counter; // Copied from the source example by flatten_example().
uint64_t ft_offset; // Copied from the source example by flatten_example().
// NOTE(review): flatten_example() does not assign this field; presumably it keeps
// whatever value the allocation left it with (zero if calloc-style) - confirm.
float global_weight;

size_t num_features; // Precomputed because it is fast and easy; copied from the source example.
float total_sum_feat_sq; // Precomputed; set by flatten_sort_example() from collision_cleanup().
features fs; // All the features.
};

flat_example* flatten_example(VW::workspace& all, example* ec);
flat_example* flatten_sort_example(VW::workspace& all, example* ec);
void free_flatten_example(flat_example* fec);
// TODO: make workspace and example const
void flatten_features(VW::workspace& all, example& ec, features& fs);

// Reports whether `ec` is flagged as a newline example (reads ec.is_newline).
inline bool example_is_newline(const example& ec)
{
  return ec.is_newline;
}

Expand Down Expand Up @@ -194,13 +176,6 @@ void truncate_example_namespace(VW::example& ec, VW::namespace_index ns, const f
void append_example_namespaces_from_example(VW::example& target, const VW::example& source);
void truncate_example_namespaces_from_example(VW::example& target, const VW::example& source);
} // namespace details

// Model (de)serialization helpers for flat_example.
namespace model_utils
{
// Reads a flat_example from `io` into `fe`, using `lbl_parser` to reset and read the label.
// Returns the accumulated byte count reported by the individual field reads.
size_t read_model_field(io_buf& io, flat_example& fe, VW::label_parser& lbl_parser);
// Writes `fe` to `io`. Each field name is prefixed with `upstream_name`; `text` is forwarded
// to the per-field writers and `parse_mask` to the feature writer.
// Returns the accumulated byte count reported by the per-field writers.
size_t write_model_field(io_buf& io, const flat_example& fe, const std::string& upstream_name, bool text,
VW::label_parser& lbl_parser, uint64_t parse_mask);
} // namespace model_utils
} // namespace VW

// Deprecated compat definitions
Expand All @@ -209,18 +184,6 @@ using polylabel VW_DEPRECATED("polylabel moved into VW namespace") = VW::polylab
using polyprediction VW_DEPRECATED("polyprediction moved into VW namespace") = VW::polyprediction;
using example VW_DEPRECATED("example moved into VW namespace") = VW::example;
using multi_ex VW_DEPRECATED("multi_ex moved into VW namespace") = VW::multi_ex;
using flat_example VW_DEPRECATED("flat_example moved into VW namespace") = VW::flat_example;

// Global-namespace compatibility shim: forwards to VW::flatten_example.
VW_DEPRECATED("flatten_example moved into VW namespace")
inline VW::flat_example* flatten_example(VW::workspace& all, VW::example* ec)
{
  return VW::flatten_example(all, ec);
}

// Global-namespace compatibility shim: forwards to VW::flatten_sort_example.
VW_DEPRECATED("flatten_sort_example moved into VW namespace")
inline VW::flat_example* flatten_sort_example(VW::workspace& all, VW::example* ec)
{
  VW::flat_example* const flattened = VW::flatten_sort_example(all, ec);
  return flattened;
}
// Global-namespace compatibility shim: forwards to VW::free_flatten_example.
VW_DEPRECATED("free_flatten_example moved into VW namespace")
inline void free_flatten_example(VW::flat_example* fec)
{
  VW::free_flatten_example(fec);
}

// Global-namespace compatibility shim: forwards to VW::example_is_newline.
VW_DEPRECATED("example_is_newline moved into VW namespace")
inline bool example_is_newline(const VW::example& ec)
{
  const bool is_newline = VW::example_is_newline(ec);
  return is_newline;
}
Expand Down
4 changes: 4 additions & 0 deletions vowpalwabbit/core/include/vw/core/feature_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,10 @@ class features
return all_extents_complete;
}
};

/// Both fs1 and fs2 must be sorted.
/// Most often used with VW::flatten_features
float features_dot_product(const features& fs1, const features& fs2);
} // namespace VW

using feature_value VW_DEPRECATED("Moved into VW namespace. Will be removed in VW 10.") = VW::feature_value;
Expand Down
82 changes: 13 additions & 69 deletions vowpalwabbit/core/src/example.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ float VW::example::get_total_sum_feat_sq()

float collision_cleanup(VW::features& fs)
{
// Input must be sorted.
assert(std::is_sorted(fs.indices.begin(), fs.indices.end()));

// This loops over the sequence of feature values and their indexes
// when an index is repeated this combines them by adding their values.
// This assumes that fs is sorted (callers such as `flatten_features` sort before calling).
Expand Down Expand Up @@ -105,46 +108,23 @@ void vec_ffs_store(full_features_and_source& p, float fx, uint64_t fi)
}
namespace VW
{
flat_example* flatten_example(VW::workspace& all, example* ec)
{
flat_example& fec = VW::details::calloc_or_throw<flat_example>();
fec.l = ec->l;
fec.tag = ec->tag;
fec.ex_reduction_features = ec->ex_reduction_features;
fec.example_counter = ec->example_counter;
fec.ft_offset = ec->ft_offset;
fec.num_features = ec->num_features;

void flatten_features(VW::workspace& all, example& ec, features& fs)
{
fs.clear();
full_features_and_source ffs;
ffs.fs = std::move(fs);
ffs.stride_shift = all.weights.stride_shift();
if (all.weights.not_null())
{ // TODO:temporary fix. all.weights is not initialized at this point in some cases.
{
// TODO:temporary fix. all.weights is not initialized at this point in some cases.
ffs.mask = all.weights.mask() >> all.weights.stride_shift();
}
else { ffs.mask = static_cast<uint64_t>(LONG_MAX) >> all.weights.stride_shift(); }
VW::foreach_feature<full_features_and_source, uint64_t, vec_ffs_store>(all, *ec, ffs);

std::swap(fec.fs, ffs.fs);

return &fec;
}

// Flattens `ec`, sorts the flattened features with the workspace parse mask, then
// runs collision_cleanup over them and stores its result as total_sum_feat_sq.
// The returned object must be released with free_flatten_example().
flat_example* flatten_sort_example(VW::workspace& all, example* ec)
{
  flat_example* const flattened = flatten_example(all, ec);
  auto& feats = flattened->fs;

  // collision_cleanup requires sorted input, so sort first.
  feats.sort(all.parse_mask);
  flattened->total_sum_feat_sq = collision_cleanup(feats);

  return flattened;
}

void free_flatten_example(flat_example* fec)
{
// note: The label memory should be freed by freeing the original example.
if (fec)
{
fec->fs.~features();
free(fec);
}
VW::foreach_feature<full_features_and_source, uint64_t, vec_ffs_store>(all, ec, ffs);
ffs.fs.sort(all.parse_mask);
ffs.fs.sum_feat_sq = collision_cleanup(ffs.fs);
fs = std::move(ffs.fs);
}

void return_multiple_example(VW::workspace& all, VW::multi_ex& examples)
Expand Down Expand Up @@ -213,42 +193,6 @@ void truncate_example_namespaces_from_example(VW::example& target, const VW::exa
}
}
} // namespace details

// Model (de)serialization of flat_example. The order of the field reads below
// mirrors the order of the field writes; keep the two functions in sync.
namespace model_utils
{
// Reads a flat_example from `io` into `fe`.
// The label is first reset via lbl_parser.default_label, then read with
// lbl_parser.read_cached_label; the remaining fields follow in write order.
// Returns the sum of the byte counts reported by each read.
size_t read_model_field(io_buf& io, flat_example& fe, VW::label_parser& lbl_parser)
{
size_t bytes = 0;
lbl_parser.default_label(fe.l);
bytes += lbl_parser.read_cached_label(fe.l, fe.ex_reduction_features, io);
bytes += read_model_field(io, fe.tag);
bytes += read_model_field(io, fe.example_counter);
bytes += read_model_field(io, fe.ft_offset);
bytes += read_model_field(io, fe.global_weight);
bytes += read_model_field(io, fe.num_features);
bytes += read_model_field(io, fe.total_sum_feat_sq);
// The cached index is consumed from the stream but its value is not used afterwards.
unsigned char index = 0;
bytes += ::VW::parsers::cache::details::read_cached_index(io, index);
// NOTE(review): `sorted` is presumably an in/out flag of read_cached_features; its value is not inspected here.
bool sorted = true;
bytes += ::VW::parsers::cache::details::read_cached_features(io, fe.fs, sorted);
return bytes;
}
// Writes `fe` to `io`: label, tag, counters and weights, then a cache index of 0 and the features.
// Field names passed to the writers are `upstream_name` plus a per-field suffix.
// Returns the sum of the byte counts reported by the write_model_field calls.
// NOTE(review): unlike the read side, whatever cache_label, cache_index and cache_features
// report is not added to `bytes`, so the returned count may under-report - confirm intent.
size_t write_model_field(io_buf& io, const flat_example& fe, const std::string& upstream_name, bool text,
VW::label_parser& lbl_parser, uint64_t parse_mask)
{
size_t bytes = 0;
lbl_parser.cache_label(fe.l, fe.ex_reduction_features, io, upstream_name + "_label", text);
bytes += write_model_field(io, fe.tag, upstream_name + "_tag", text);
bytes += write_model_field(io, fe.example_counter, upstream_name + "_example_counter", text);
bytes += write_model_field(io, fe.ft_offset, upstream_name + "_ft_offset", text);
bytes += write_model_field(io, fe.global_weight, upstream_name + "_global_weight", text);
bytes += write_model_field(io, fe.num_features, upstream_name + "_num_features", text);
bytes += write_model_field(io, fe.total_sum_feat_sq, upstream_name + "_total_sum_feat_sq", text);
// Index 0 is written here; read_model_field reads an index back at the same position.
::VW::parsers::cache::details::cache_index(io, 0);
::VW::parsers::cache::details::cache_features(io, fe.fs, parse_mask);
return bytes;
}
} // namespace model_utils
} // namespace VW

namespace VW
Expand Down
25 changes: 25 additions & 0 deletions vowpalwabbit/core/src/feature_group.cc
Original file line number Diff line number Diff line change
Expand Up @@ -297,3 +297,28 @@ void VW::features::end_ns_extent()
}
}
}

// Dot product of two sparse feature groups, computed with one merge-style pass.
// Both inputs must have their indices in non-decreasing order (asserted in debug builds).
// Only indices present in both groups contribute a term to the result.
float VW::features_dot_product(const features& fs1, const features& fs2)
{
  assert(std::is_sorted(fs1.indices.begin(), fs1.indices.end()));
  assert(std::is_sorted(fs2.indices.begin(), fs2.indices.end()));

  if (fs2.indices.empty()) { return 0.f; }

  const size_t count1 = fs1.size();
  const size_t count2 = fs2.size();

  float total = 0;
  size_t i = 0;
  size_t j = 0;
  while (i < count1 && j < count2)
  {
    const uint64_t left = fs1.indices[i];
    const uint64_t right = fs2.indices[j];

    if (left < right)
    {
      // fs1 is behind: move to its next feature.
      ++i;
    }
    else if (left > right)
    {
      // fs2 is behind: let it catch up to the current fs1 index.
      ++j;
    }
    else
    {
      // Shared index: accumulate the product and step past it on both sides.
      total += fs1.values[i] * fs2.values[j];
      ++i;
      ++j;
    }
  }
  return total;
}
13 changes: 7 additions & 6 deletions vowpalwabbit/core/src/reductions/eigen_memory_tree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "vw/config/options.h"
#include "vw/core/array_parameters.h"
#include "vw/core/example.h"
#include "vw/core/feature_group.h"
#include "vw/core/learner.h"
#include "vw/core/memory.h"
#include "vw/core/model_utils.h"
Expand Down Expand Up @@ -75,14 +76,14 @@ emt_example::emt_example(VW::workspace& all, VW::example* ex)
std::vector<std::vector<VW::namespace_index>> base_interactions;

ex->interactions = &base_interactions;
auto* ex1 = VW::flatten_sort_example(all, ex);
for (auto& f : ex1->fs) { base.emplace_back(f.index(), f.value()); }
VW::free_flatten_example(ex1);
VW::features fs;
VW::flatten_features(all, *ex, fs);
for (auto& f : fs) { base.emplace_back(f.index(), f.value()); }

fs.clear();
ex->interactions = full_interactions;
auto* ex2 = VW::flatten_sort_example(all, ex);
for (auto& f : ex2->fs) { full.emplace_back(f.index(), f.value()); }
VW::free_flatten_example(ex2);
VW::flatten_features(all, *ex, fs);
for (auto& f : fs) { full.emplace_back(f.index(), f.value()); }
}

emt_lru::emt_lru(uint64_t max_size) : max_size(max_size) {}
Expand Down
Loading