fix!: [LAS] las + squarecb to re-use squarecb gamma (#4479)
olgavrou authored Feb 1, 2023
1 parent c7dfdc2 commit 2bde23d
Showing 12 changed files with 364 additions and 63 deletions.
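Before the per-file hunks, a minimal, self-contained sketch of the handoff this commit sets up, using simplified free functions rather than the actual VW reduction API (names follow the diff; everything else here is illustrative): SquareCB computes its gamma once per predict and publishes it in the shared reduction features, and LAS reads that value when computing shrink factors instead of deriving gamma from its own example counter.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the VW class extended in the diff below.
struct las_reduction_features
{
  float squarecb_gamma = 1.f;  // new field carrying SquareCB's gamma over to LAS
};

// SquareCB side (cf. cb_explore_adf_squarecb::predict): compute gamma from the
// example counter and, when --large_action_space is enabled, publish it.
inline void publish_gamma(las_reduction_features& red_features, float gamma_scale, float gamma_exponent,
    size_t counter, bool store_gamma_in_reduction_features)
{
  const float gamma = gamma_scale * static_cast<float>(std::pow(counter, gamma_exponent));
  if (store_gamma_in_reduction_features) { red_features.squarecb_gamma = gamma; }
}

// LAS side (cf. shrink_factor_config::calculate_shrink_factor): consume the
// published gamma instead of recomputing it from a separately tracked counter.
inline void shrink_with_published_gamma(const las_reduction_features& red_features, size_t max_actions,
    const std::vector<float>& scores, std::vector<float>& shrink_factors)
{
  const float gamma = red_features.squarecb_gamma;
  const float min_ck = *std::min_element(scores.begin(), scores.end());
  shrink_factors.clear();
  for (float score : scores)
  {
    shrink_factors.push_back(std::sqrt(1 + max_actions + gamma / (4.0f * max_actions) * (score - min_ck)));
  }
}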
39 changes: 39 additions & 0 deletions test/core.vwtest.json
@@ -5766,5 +5766,44 @@
"depends_on": [
441, 443
]
},
{
"id": 446,
"desc": "large action spaces with cb_explore_adf epsilon greedy",
"vw_command": "--cb_explore_adf -d train-sets/las_100_actions.txt --noconstant --large_action_space --extra_metrics metrics_las_e.json",
"diff_files": {
"stderr": "train-sets/ref/las_egreedy.stderr",
"metrics_las_e.json": "test-sets/ref/metrics_las_e.json"
},
"input_files": [
"train-sets/las_100_actions.txt"
]
},
{
"id": 447,
"desc": "large action spaces with cb_explore_adf squarecb",
"vw_command": "--cb_explore_adf -d train-sets/las_100_actions.txt --squarecb --noconstant --large_action_space --extra_metrics metrics_las_sqcb.json",
"diff_files": {
"stderr": "train-sets/ref/las_sqcb.stderr",
"metrics_las_sqcb.json": "test-sets/ref/metrics_las_sqcb.json"
},
"input_files": [
"train-sets/las_100_actions.txt"
]
},
{
"id": 448,
"desc": "transition from squarecb model to las + squarecb",
"vw_command": "-d train-sets/cb_load.dat --cb_explore_adf -q UA --squarecb -i models/sqcb_ld.model --large_action_space --max_actions 2",
"diff_files": {
"stderr": "train-sets/ref/sqcb_to_las.stderr"
},
"input_files": [
"train-sets/cb_load.dat",
"models/sqcb_ld.model"
],
"depends_on": [
314
]
}
]
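For reference, the SquareCB + LAS test above (id 447) boils down to the following invocation from the test/ directory, assuming the harness passes vw_command straight to the vw binary:

vw --cb_explore_adf -d train-sets/las_100_actions.txt --squarecb --noconstant --large_action_space --extra_metrics metrics_las_sqcb.json

The epsilon-greedy variant (id 446) is the same invocation without --squarecb and with a different metrics output file.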
1 change: 1 addition & 0 deletions test/test-sets/ref/metrics_las_e.json
@@ -0,0 +1 @@
{"cb_las_filtering_factor":15,"cbea_label_first_action":0,"cbea_label_not_first":2,"cbea_labeled_ex":2,"cbea_max_actions":99,"cbea_min_actions":99,"cbea_non_zero_cost":2,"cbea_predict_in_learn":0,"sfm_count_learn_example_with_shared":2,"total_learn_calls":2,"total_log_calls":0,"total_predict_calls":2,"cbea_avg_actions_per_event":99.0,"cbea_avg_feat_per_action":48.0,"cbea_avg_feat_per_event":4761.0,"cbea_avg_ns_per_action":2.0,"cbea_avg_ns_per_event":198.0,"cbea_sum_cost":-2.0,"cbea_sum_cost_baseline":0.0}
1 change: 1 addition & 0 deletions test/test-sets/ref/metrics_las_sqcb.json
@@ -0,0 +1 @@
{"cb_las_filtering_factor":15,"cbea_label_first_action":0,"cbea_label_not_first":2,"cbea_labeled_ex":2,"cbea_max_actions":99,"cbea_min_actions":99,"cbea_non_zero_cost":2,"cbea_predict_in_learn":0,"sfm_count_learn_example_with_shared":2,"total_learn_calls":2,"total_log_calls":0,"total_predict_calls":2,"cbea_avg_actions_per_event":99.0,"cbea_avg_feat_per_action":48.0,"cbea_avg_feat_per_event":4761.0,"cbea_avg_ns_per_action":2.0,"cbea_avg_ns_per_event":198.0,"cbea_sum_cost":-2.0,"cbea_sum_cost_baseline":0.0}
202 changes: 202 additions & 0 deletions test/train-sets/las_100_actions.txt

Large diffs are not rendered by default.

22 changes: 22 additions & 0 deletions test/train-sets/ref/las_egreedy.stderr
@@ -0,0 +1,22 @@
using no cache
Reading datafile = train-sets/las_100_actions.txt
num sources = 1
Num weight bits = 18
learning rate = 0.5
initial_t = 0
power_t = 0.5
cb_type = mtr
Enabled reductions: gd, scorer-identity, csoaa_ldf-rank, cb_adf, cb_explore_adf_large_action_space, cb_explore_adf_greedy, cb_actions_mask, shared_feature_merger, extra_metrics
Input label = CB
Output pred = ACTION_PROBS
average since example example current current current
loss last counter weight label predict features
0.000000 0.000000 1 1.0 47:-1:0.17 0:0.06 4761
-0.00153 -0.00306 2 2.0 14:-1:0.96 65:0.95 4761

finished run
number of examples = 2
weighted example sum = 2.000000
weighted label sum = 0.000000
average loss = -0.001535
total feature number = 9522
22 changes: 22 additions & 0 deletions test/train-sets/ref/las_sqcb.stderr
@@ -0,0 +1,22 @@
using no cache
Reading datafile = train-sets/las_100_actions.txt
num sources = 1
Num weight bits = 18
learning rate = 0.5
initial_t = 0
power_t = 0.5
cb_type = mtr
Enabled reductions: gd, scorer-identity, csoaa_ldf-rank, cb_adf, cb_explore_adf_large_action_space, cb_explore_adf_squarecb, cb_actions_mask, shared_feature_merger, extra_metrics
Input label = CB
Output pred = ACTION_PROBS
average since example example current current current
loss last counter weight label predict features
0.000000 0.000000 1 1.0 47:-1:0.17 0:0.06 4761
-0.02926 -0.05853 2 2.0 14:-1:0.96 65:0.17 4761

finished run
number of examples = 2
weighted example sum = 2.000000
weighted label sum = 0.000000
average loss = -0.029265
total feature number = 9522
28 changes: 28 additions & 0 deletions test/train-sets/ref/sqcb_to_las.stderr
@@ -0,0 +1,28 @@
creating quadratic features for pairs: UA
using no cache
Reading datafile = train-sets/cb_load.dat
num sources = 1
Num weight bits = 18
learning rate = 0.5
initial_t = 113
power_t = 0.5
cb_type = mtr
Enabled reductions: gd, scorer-identity, csoaa_ldf-rank, cb_adf, cb_explore_adf_large_action_space, cb_explore_adf_squarecb, cb_actions_mask, shared_feature_merger
Input label = CB
Output pred = ACTION_PROBS
average since example example current current current
loss last counter weight label predict features
0.000000 0.000000 1 1.0 4:0:0.14 0:0.97 42
0.000000 0.000000 2 2.0 0:0:0.14 0:0.94 42
0.000000 0.000000 4 4.0 1:0:0.14 2:0.97 42
0.000000 0.000000 8 8.0 3:0:0.14 2:0.87 42
0.000000 0.000000 16 16.0 3:0:0.14 2:0.88 42
0.000000 0.000000 32 32.0 3:0:0.14 2:0.98 42
-0.58935 -1.17870 64 64.0 0:0:0.77 0:0.97 42

finished run
number of examples = 113
weighted example sum = 113.000000
weighted label sum = 0.000000
average loss = -0.581254
total feature number = 4746
@@ -20,6 +20,7 @@ class las_reduction_features
std::vector<std::vector<VW::namespace_index>>* generated_interactions = nullptr;
std::vector<std::vector<extent_term>>* generated_extent_interactions = nullptr;
VW::multi_ex::value_type shared_example = nullptr;
float squarecb_gamma = 1.f;

las_reduction_features() = default;

@@ -28,6 +29,7 @@ class las_reduction_features
generated_interactions = nullptr;
generated_extent_interactions = nullptr;
shared_example = nullptr;
squarecb_gamma = 1.f;
}
};

@@ -112,15 +112,6 @@ bool _test_only_generate_A(VW::workspace* _all, const multi_ex& examples, std::v
return (_A.cols() != 0 && _A.rows() != 0);
}

template <typename randomized_svd_impl, typename spanner_impl>
void cb_explore_adf_large_action_space<randomized_svd_impl, spanner_impl>::save_load(io_buf& io, bool read, bool text)
{
if (io.num_files() == 0) { return; }

if (read) { model_utils::read_model_field(io, _counter); }
else { model_utils::write_model_field(io, _counter, "cb large action space storing example counter", text); }
}

template <typename randomized_svd_impl, typename spanner_impl>
void cb_explore_adf_large_action_space<randomized_svd_impl, spanner_impl>::randomized_SVD(const multi_ex& examples)
{
@@ -157,7 +148,10 @@ void cb_explore_adf_large_action_space<randomized_svd_impl, spanner_impl>::updat

if (_d < preds.size())
{
shrink_fact_config.calculate_shrink_factor(_counter, _d, preds, shrink_factors);
auto& red_features =
examples[0]->ex_reduction_features.template get<VW::large_action_space::las_reduction_features>();

shrink_fact_config.calculate_shrink_factor(red_features.squarecb_gamma, _d, preds, shrink_factors);
randomized_SVD(examples);

// The U matrix is empty before learning anything.
@@ -205,9 +199,9 @@ template <typename randomized_svd_impl, typename spanner_impl>
void cb_explore_adf_large_action_space<randomized_svd_impl, spanner_impl>::learn(
VW::LEARNER::multi_learner& base, multi_ex& examples)
{
VW::v_array<VW::action_score> preds = std::move(examples[0]->pred.a_s);
auto restore_guard = VW::scope_exit([&preds, &examples] { examples[0]->pred.a_s = std::move(preds); });
base.learn(examples);
if (base.learn_returns_prediction) { update_example_prediction(examples); }
++_counter;
}

void generate_Z(const multi_ex& examples, Eigen::MatrixXf& Z, Eigen::MatrixXf& B, uint64_t d, uint64_t seed)
@@ -235,33 +229,28 @@ void generate_Z(const multi_ex& examples, Eigen::MatrixXf& Z, Eigen::MatrixXf& B
}

template <typename T, typename S>
cb_explore_adf_large_action_space<T, S>::cb_explore_adf_large_action_space(uint64_t d, float gamma_scale,
float gamma_exponent, float c, bool apply_shrink_factor, VW::workspace* all, uint64_t seed, size_t total_size,
size_t thread_pool_size, size_t block_size, bool use_explicit_simd, implementation_type impl_type)
cb_explore_adf_large_action_space<T, S>::cb_explore_adf_large_action_space(uint64_t d, float c,
bool apply_shrink_factor, VW::workspace* all, uint64_t seed, size_t total_size, size_t thread_pool_size,
size_t block_size, bool use_explicit_simd, implementation_type impl_type)
: _d(d)
, _all(all)
, _counter(0)
, _seed(seed)
, _impl_type(impl_type)
, spanner_state(c, d)
, shrink_fact_config(gamma_scale, gamma_exponent, apply_shrink_factor)
, shrink_fact_config(apply_shrink_factor)
, impl(all, d, _seed, total_size, thread_pool_size, block_size, use_explicit_simd)
{
}

shrink_factor_config::shrink_factor_config(float gamma_scale, float gamma_exponent, bool apply_shrink_factor)
: _gamma_scale(gamma_scale), _gamma_exponent(gamma_exponent), _apply_shrink_factor(apply_shrink_factor)
{
}
shrink_factor_config::shrink_factor_config(bool apply_shrink_factor) : _apply_shrink_factor(apply_shrink_factor) {}

void shrink_factor_config::calculate_shrink_factor(
size_t counter, size_t max_actions, const VW::action_scores& preds, std::vector<float>& shrink_factors)
float gamma, size_t max_actions, const VW::action_scores& preds, std::vector<float>& shrink_factors)
{
if (_apply_shrink_factor)
{
shrink_factors.clear();
float min_ck = std::min_element(preds.begin(), preds.end())->score;
float gamma = _gamma_scale * static_cast<float>(std::pow(counter, _gamma_exponent));
for (size_t i = 0; i < preds.size(); i++)
{
shrink_factors.push_back(std::sqrt(1 + max_actions + gamma / (4.0f * max_actions) * (preds[i].score - min_ck)));
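Read off the push_back above, with d = max_actions, \hat{c}_i = preds[i].score, and \gamma now being the value SquareCB published via las_reduction_features::squarecb_gamma, each shrink factor is

  s_i = \sqrt{1 + d + \frac{\gamma}{4d}\left(\hat{c}_i - \min_j \hat{c}_j\right)}

which is the same expression as before this change, except that \gamma is no longer recomputed here from LAS's own example counter.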
@@ -283,12 +272,6 @@ void persist_metrics(cb_explore_adf_large_action_space<T, S>& data, VW::metric_s
metrics.set_uint("cb_las_filtering_factor", data.number_of_non_degenerate_singular_values());
}

template <typename T, typename S>
void save_load(cb_explore_adf_large_action_space<T, S>& data, VW::io_buf& io, bool read, bool text)
{
data.save_load(io, read, text);
}

template <typename T, typename S>
void predict(cb_explore_adf_large_action_space<T, S>& data, VW::LEARNER::multi_learner& base, VW::multi_ex& examples)
{
@@ -303,15 +286,15 @@ void learn(cb_explore_adf_large_action_space<T, S>& data, VW::LEARNER::multi_lea

template <typename T, typename S>
VW::LEARNER::base_learner* make_las_with_impl(VW::setup_base_i& stack_builder, VW::LEARNER::multi_learner* base,
implementation_type& impl_type, VW::workspace& all, uint64_t d, float gamma_scale, float gamma_exponent, float c,
bool apply_shrink_factor, size_t thread_pool_size, size_t block_size, bool use_explicit_simd)
implementation_type& impl_type, VW::workspace& all, uint64_t d, float c, bool apply_shrink_factor,
size_t thread_pool_size, size_t block_size, bool use_explicit_simd)
{
size_t problem_multiplier = 1;

float seed = (all.get_random_state()->get_random() + 1) * 10.f;

auto data = VW::make_unique<cb_explore_adf_large_action_space<T, S>>(d, gamma_scale, gamma_exponent, c,
apply_shrink_factor, &all, seed, 1 << all.num_bits, thread_pool_size, block_size, use_explicit_simd, impl_type);
auto data = VW::make_unique<cb_explore_adf_large_action_space<T, S>>(d, c, apply_shrink_factor, &all, seed,
1 << all.num_bits, thread_pool_size, block_size, use_explicit_simd, impl_type);

auto* l = make_reduction_learner(std::move(data), base, learn<T, S>, predict<T, S>,
stack_builder.get_setupfn_name(VW::reductions::cb_explore_adf_large_action_space_setup))
@@ -321,8 +304,7 @@ VW::LEARNER::base_learner* make_las_with_impl(VW::setup_base_i& stack_builder, V
.set_output_prediction_type(VW::prediction_type_t::ACTION_SCORES)
.set_params_per_weight(problem_multiplier)
.set_persist_metrics(persist_metrics<T, S>)
.set_save_load(save_load<T, S>)
.set_learn_returns_prediction(base->learn_returns_prediction)
.set_learn_returns_prediction(false)
.build();
return VW::LEARNER::make_base(*l);
}
@@ -336,8 +318,6 @@ VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_large_action_space_set
bool cb_explore_adf_option = false;
bool large_action_space = false;
uint64_t d;
float gamma_scale = 1.f;
float gamma_exponent = 0.f;
float c;
bool apply_shrink_factor = false;
bool use_two_pass_svd_impl = false;
@@ -389,12 +369,7 @@ VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_large_action_space_set
auto enabled = options.add_parse_and_check_necessary(new_options) && large_action_space;
if (!enabled) { return nullptr; }

if (options.was_supplied("squarecb"))
{
apply_shrink_factor = true;
gamma_scale = options.get_typed_option<float>("gamma_scale").value();
gamma_exponent = options.get_typed_option<float>("gamma_exponent").value();
}
if (options.was_supplied("squarecb")) { apply_shrink_factor = true; }

if (options.was_supplied("cb_type"))
{
@@ -414,15 +389,14 @@ VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_large_action_space_set
if (use_two_pass_svd_impl)
{
auto impl_type = implementation_type::two_pass_svd;
return make_las_with_impl<two_pass_svd_impl, one_rank_spanner_state>(stack_builder, base, impl_type, all, d,
gamma_scale, gamma_exponent, c, apply_shrink_factor, thread_pool_size, block_size,
return make_las_with_impl<two_pass_svd_impl, one_rank_spanner_state>(stack_builder, base, impl_type, all, d, c,
apply_shrink_factor, thread_pool_size, block_size,
/*use_explicit_simd=*/false);
}
else
{
auto impl_type = implementation_type::one_pass_svd;
return make_las_with_impl<one_pass_svd_impl, one_rank_spanner_state>(stack_builder, base, impl_type, all, d,
gamma_scale, gamma_exponent, c, apply_shrink_factor, thread_pool_size, block_size,
use_simd_in_one_pass_svd_impl);
return make_las_with_impl<one_pass_svd_impl, one_rank_spanner_state>(stack_builder, base, impl_type, all, d, c,
apply_shrink_factor, thread_pool_size, block_size, use_simd_in_one_pass_svd_impl);
}
}
25 changes: 19 additions & 6 deletions vowpalwabbit/core/src/reductions/cb/cb_explore_adf_squarecb.cc
@@ -43,7 +43,7 @@ class cb_explore_adf_squarecb
{
public:
cb_explore_adf_squarecb(float gamma_scale, float gamma_exponent, bool elim, float c0, float min_cb_cost,
float max_cb_cost, VW::version_struct model_file_version, float epsilon);
float max_cb_cost, VW::version_struct model_file_version, float epsilon, bool store_gamma_in_reduction_features);
~cb_explore_adf_squarecb() = default;

// Should be called through cb_explore_adf_base for pre/post-processing
@@ -68,6 +68,8 @@ class cb_explore_adf_squarecb

VW::version_struct _model_file_version;

bool _store_gamma_in_reduction_features;

// for backing up cb example data when computing sensitivities
std::vector<VW::action_scores> _ex_as;
std::vector<std::vector<VW::cb_class>> _ex_costs;
@@ -76,7 +78,8 @@ };
};

cb_explore_adf_squarecb::cb_explore_adf_squarecb(float gamma_scale, float gamma_exponent, bool elim, float c0,
float min_cb_cost, float max_cb_cost, VW::version_struct model_file_version, float epsilon)
float min_cb_cost, float max_cb_cost, VW::version_struct model_file_version, float epsilon,
bool store_gamma_in_reduction_features)
: _counter(0)
, _gamma_scale(gamma_scale)
, _gamma_exponent(gamma_exponent)
Expand All @@ -86,6 +89,7 @@ cb_explore_adf_squarecb::cb_explore_adf_squarecb(float gamma_scale, float gamma_
, _max_cb_cost(max_cb_cost)
, _epsilon(epsilon)
, _model_file_version(model_file_version)
, _store_gamma_in_reduction_features(store_gamma_in_reduction_features)
{
}

@@ -186,14 +190,20 @@ void cb_explore_adf_squarecb::get_cost_ranges(float delta, multi_learner& base,

void cb_explore_adf_squarecb::predict(multi_learner& base, VW::multi_ex& examples)
{
// The actual parameter $\gamma$ used in the SquareCB.
const float gamma = _gamma_scale * static_cast<float>(std::pow(_counter, _gamma_exponent));
if (_store_gamma_in_reduction_features)
{
auto& red_features =
examples[0]->ex_reduction_features.template get<VW::large_action_space::las_reduction_features>();
red_features.squarecb_gamma = gamma;
}

multiline_learn_or_predict<false>(base, examples, examples[0]->ft_offset);

VW::v_array<VW::action_score>& preds = examples[0]->pred.a_s;
uint32_t num_actions = static_cast<uint32_t>(preds.size());

// The actual parameter $\gamma$ used in the SquareCB.
const float gamma = _gamma_scale * static_cast<float>(std::pow(_counter, _gamma_exponent));

// RegCB action set parameters
const float max_range = _max_cb_cost - _min_cb_cost;
// threshold on empirical loss difference
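The \gamma used above follows the usual SquareCB schedule driven by the reduction's example counter t = _counter:

  \gamma_t = \gamma_{\text{scale}} \cdot t^{\,\gamma_{\text{exponent}}}

and when store_gamma_in_reduction_features is set (i.e. --large_action_space was also supplied, see the setup change further down), the same value is written into las_reduction_features::squarecb_gamma so the LAS reduction reuses it rather than maintaining its own copy.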
@@ -374,6 +384,9 @@ VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_squarecb_setup(VW::set
// Ensure serialization of cb_adf in all cases.
if (!options.was_supplied("cb_adf")) { options.insert("cb_adf", ""); }

bool store_gamma_in_reduction_features = false;
if (options.was_supplied("large_action_space")) { store_gamma_in_reduction_features = true; }

// Set explore_type
size_t problem_multiplier = 1;

Expand All @@ -384,7 +397,7 @@ VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_squarecb_setup(VW::set

using explore_type = cb_explore_adf_base<cb_explore_adf_squarecb>;
auto data = VW::make_unique<explore_type>(all.global_metrics.are_metrics_enabled(), gamma_scale, gamma_exponent, elim,
c0, min_cb_cost, max_cb_cost, all.model_file_ver, epsilon);
c0, min_cb_cost, max_cb_cost, all.model_file_ver, epsilon, store_gamma_in_reduction_features);
auto* l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict,
stack_builder.get_setupfn_name(cb_explore_adf_squarecb_setup))
.set_input_label_type(VW::label_type_t::CB)