feat: Improve large actions multithreading. (#4158)
* Partition work to blocks.

* Update cb_las benchmarks.

* Add block_size parameter.

* Update benchmarks.

* Fix test and format.

Co-authored-by: olgavrou <[email protected]>
zwd-ms and olgavrou authored Sep 27, 2022
1 parent eb5c59e commit 60b2580
Showing 9 changed files with 116 additions and 57 deletions.
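The core of this commit is in one_pass_svd_impl::generate_AOmega (shown further down in the diff): instead of submitting one thread-pool task per example (one row of AOmega), the rows are now partitioned into contiguous blocks of block_size rows and each block is submitted as a single task, which reduces task-submission overhead when the action set is large. The following is a minimal, hypothetical sketch of that scheduling pattern only: std::async stands in for VW's internal thread_pool, doubling a float stands in for the real per-row AOmega computation, and the lower bound of 1 on the computed block size is a guard added for the sketch, not something the committed code does.

#include <algorithm>
#include <cstddef>
#include <future>
#include <iostream>
#include <vector>

// Sketch: process `rows` in contiguous blocks of `block_size`, one task per block.
// Assumptions: std::async stands in for VW's thread_pool; the per-row work is a
// placeholder (doubling a value) rather than the real AOmega row computation.
void process_rows_in_blocks(std::vector<float>& rows, size_t thread_pool_size, size_t block_size)
{
  if (block_size == 0)
  {
    // Default from the new --block_size help text: split the rows evenly across the pool.
    // The max(1, ...) guards are added for this sketch only.
    const size_t num_blocks = std::max<size_t>(1, thread_pool_size);
    block_size = std::max<size_t>(1, rows.size() / num_blocks);
  }

  std::vector<std::future<void>> futures;
  for (size_t begin = 0; begin < rows.size();)
  {
    size_t end = begin + block_size;
    // If the remaining tail would be shorter than a full block, fold it into this block.
    if (end + block_size > rows.size()) { end = rows.size(); }

    futures.emplace_back(std::async(std::launch::async, [&rows, begin, end]() {
      for (size_t row = begin; row < end; ++row) { rows[row] *= 2.f; }
    }));

    begin = end;
  }

  // Wait for every submitted block, as generate_AOmega does with its _futures vector.
  for (auto& f : futures) { f.get(); }
}

int main()
{
  std::vector<float> rows(50, 1.f);
  process_rows_in_blocks(rows, /*thread_pool_size*/ 8, /*block_size*/ 0);
  std::cout << rows.front() << " " << rows.back() << "\n";  // prints: 2 2
  return 0;
}

With block_size left at its default of 0, the number of tasks roughly equals the pool size, so each thread gets about one block; a smaller explicit --block_size trades more submission overhead for finer-grained load balancing.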
56 changes: 49 additions & 7 deletions test/benchmarks/standalone/benchmark_text_input.cc
@@ -277,25 +277,25 @@ BENCHMARK_CAPTURE(benchmark_multi, ccb_adf_same_char_no_interactions,
BENCHMARK_CAPTURE(benchmark_multi, ccb_adf_same_char_interactions, gen_ccb_examples(50, 7, 3, 6, 3, 4, 14, 2, true, 3),
"--ccb_explore_adf --quiet -q ::")
->MinTime(15.0);
#ifdef BUILD_LARGE_ACTION_SPACE

BENCHMARK_CAPTURE(benchmark_multi_predict, cb_las_medium_300_onestep,
gen_cb_examples(1, 50, 10, 300, 1, 1, 20, 10, false),
#ifdef BUILD_LARGE_ACTION_SPACE
BENCHMARK_CAPTURE(benchmark_multi_predict, cb_las_small_300_onestep,
gen_cb_examples(1, 50, 10, 300, 5, 5, 20, 10, false),
"--cb_explore_adf --large_action_space -q :: --max_actions 20 --quiet")
->MinTime(15.0)
->UseRealTime()
->Unit(benchmark::kMillisecond);

BENCHMARK_CAPTURE(benchmark_multi_predict, cb_las_medium_300_onestep_max_threads,
gen_cb_examples(1, 50, 10, 300, 1, 1, 20, 10, false),
BENCHMARK_CAPTURE(benchmark_multi_predict, cb_las_small_300_onestep_max_threads,
gen_cb_examples(1, 50, 10, 300, 5, 5, 20, 10, false),
"--cb_explore_adf --large_action_space -q :: --max_actions 20 --quiet --thread_pool_size " +
std::to_string(std::thread::hardware_concurrency()))
->MinTime(15.0)
->UseRealTime()
->Unit(benchmark::kMillisecond);

BENCHMARK_CAPTURE(benchmark_multi_predict, cb_las_medium_300_plaincb,
gen_cb_examples(1, 50, 10, 300, 1, 1, 20, 10, false), "--cb_explore_adf -q :: --quiet")
BENCHMARK_CAPTURE(benchmark_multi_predict, cb_las_small_300_plaincb,
gen_cb_examples(1, 50, 10, 300, 5, 5, 20, 10, false), "--cb_explore_adf -q :: --quiet")
->MinTime(15.0)
->UseRealTime()
->Unit(benchmark::kMillisecond);
@@ -342,6 +342,27 @@ BENCHMARK_CAPTURE(benchmark_multi_predict, cb_las_small_1k_plaincb,
->UseRealTime()
->Unit(benchmark::kMillisecond);

BENCHMARK_CAPTURE(benchmark_multi_predict, cb_las_medium_300_onestep,
gen_cb_examples(1, 50, 20, 300, 5, 5, 20, 10, false),
"--cb_explore_adf --large_action_space -q :: --max_actions 20 --quiet")
->MinTime(15.0)
->UseRealTime()
->Unit(benchmark::kMillisecond);

BENCHMARK_CAPTURE(benchmark_multi_predict, cb_las_medium_300_onestep_max_threads,
gen_cb_examples(1, 50, 20, 300, 5, 5, 20, 10, false),
"--cb_explore_adf --large_action_space -q :: --max_actions 20 --quiet --thread_pool_size " +
std::to_string(std::thread::hardware_concurrency()))
->MinTime(15.0)
->UseRealTime()
->Unit(benchmark::kMillisecond);

BENCHMARK_CAPTURE(benchmark_multi_predict, cb_las_medium_300_plaincb,
gen_cb_examples(1, 50, 20, 300, 5, 5, 20, 10, false), "--cb_explore_adf -q :: --quiet")
->MinTime(15.0)
->UseRealTime()
->Unit(benchmark::kMillisecond);

BENCHMARK_CAPTURE(benchmark_multi_predict, cb_las_medium_500_onestep,
gen_cb_examples(1, 50, 20, 500, 5, 5, 20, 10, false),
"--cb_explore_adf --large_action_space -q :: --max_actions 20 --quiet")
@@ -384,6 +405,27 @@ BENCHMARK_CAPTURE(benchmark_multi_predict, cb_las_medium_1k_plaincb,
->UseRealTime()
->Unit(benchmark::kMillisecond);

BENCHMARK_CAPTURE(benchmark_multi_predict, cb_las_large_300_onestep,
gen_cb_examples(1, 50, 50, 300, 5, 5, 20, 10, false),
"--cb_explore_adf --large_action_space -q :: --max_actions 20 --quiet")
->MinTime(15.0)
->UseRealTime()
->Unit(benchmark::kMillisecond);

BENCHMARK_CAPTURE(benchmark_multi_predict, cb_las_large_300_onestep_max_threads,
gen_cb_examples(1, 50, 50, 300, 5, 5, 20, 10, false),
"--cb_explore_adf --large_action_space -q :: --max_actions 20 --quiet --thread_pool_size " +
std::to_string(std::thread::hardware_concurrency()))
->MinTime(15.0)
->UseRealTime()
->Unit(benchmark::kMillisecond);

BENCHMARK_CAPTURE(benchmark_multi_predict, cb_las_large_300_plaincb,
gen_cb_examples(1, 50, 50, 300, 5, 5, 20, 10, false), "--cb_explore_adf -q :: --quiet")
->MinTime(15.0)
->UseRealTime()
->Unit(benchmark::kMillisecond);

BENCHMARK_CAPTURE(benchmark_multi_predict, cb_las_large_500_onestep,
gen_cb_examples(1, 50, 50, 500, 5, 5, 20, 10, false),
"--cb_explore_adf --large_action_space -q :: --max_actions 20 --quiet")
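To run just the large-action-space captures registered above, the standard Google Benchmark command-line filter applies; for example (the benchmark binary name is a placeholder and depends on how the benchmarks target is built in this repository):

./vw_benchmarks --benchmark_filter='cb_las_(medium|large)_300'

--benchmark_filter takes a regular expression over the registered capture names; the 15-second minimum run time, real-time measurement, and millisecond units are already set on each capture via MinTime, UseRealTime, and Unit.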
4 changes: 2 additions & 2 deletions test/unit_test/cb_large_actions_test.cc
@@ -107,7 +107,7 @@ BOOST_AUTO_TEST_CASE(test_two_Ys_are_equal)
BOOST_CHECK_EQUAL(action_space != nullptr, true);

VW::cb_explore_adf::model_weight_rand_svd_impl _model_weight_rand_svd_impl(
&vw, d, 50, 1 << vw.num_bits, /*thread_pool_size*/ 0);
&vw, d, 50, 1 << vw.num_bits, /*thread_pool_size*/ 0, /*block_size*/ 0);

{
VW::multi_ex examples;
@@ -157,7 +157,7 @@ BOOST_AUTO_TEST_CASE(test_two_Bs_are_equal)
BOOST_CHECK_EQUAL(action_space != nullptr, true);

VW::cb_explore_adf::model_weight_rand_svd_impl _model_weight_rand_svd_impl(
&vw, d, 50, 1 << vw.num_bits, /*thread_pool_size*/ 0);
&vw, d, 50, 1 << vw.num_bits, /*thread_pool_size*/ 0, /*block_size*/ 0);

{
VW::multi_ex examples;
2 changes: 1 addition & 1 deletion test/unit_test/cb_las_spanner_test.cc
@@ -30,7 +30,7 @@ BOOST_AUTO_TEST_CASE(check_finding_max_volume)
VW::cb_explore_adf::one_rank_spanner_state>
largecb(
/*d=*/0, /*gamma_scale=*/1.f, /*gamma_exponent=*/0.f, /*c=*/2, false, &vw, seed, 1 << vw.num_bits,
/*thread_pool_size*/ 0, VW::cb_explore_adf::implementation_type::one_pass_svd);
/*thread_pool_size*/ 0, /*block_size*/ 0, VW::cb_explore_adf::implementation_type::one_pass_svd);
largecb.U = Eigen::MatrixXf{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {0, 0, 0}, {7, 5, 3}, {6, 4, 8}};
Eigen::MatrixXf X{{1, 2, 3}, {3, 2, 1}, {2, 1, 3}};

Original file line number Diff line number Diff line change
@@ -243,15 +243,15 @@ void generate_Z(const multi_ex& examples, Eigen::MatrixXf& Z, Eigen::MatrixXf& B
template <typename T, typename S>
cb_explore_adf_large_action_space<T, S>::cb_explore_adf_large_action_space(uint64_t d, float gamma_scale,
float gamma_exponent, float c, bool apply_shrink_factor, VW::workspace* all, uint64_t seed, size_t total_size,
size_t thread_pool_size, implementation_type impl_type)
size_t thread_pool_size, size_t block_size, implementation_type impl_type)
: _d(d)
, _all(all)
, _counter(0)
, _seed(seed)
, _impl_type(impl_type)
, spanner_state(c, d)
, shrink_fact_config(gamma_scale, gamma_exponent, apply_shrink_factor)
, impl(all, d, _seed, total_size, thread_pool_size)
, impl(all, d, _seed, total_size, thread_pool_size, block_size)
{
}

@@ -288,7 +288,7 @@ template struct cb_explore_adf_large_action_space<model_weight_rand_svd_impl, on
template <typename T, typename S>
VW::LEARNER::base_learner* make_las_with_impl(VW::setup_base_i& stack_builder, VW::LEARNER::multi_learner* base,
implementation_type& impl_type, VW::workspace& all, bool with_metrics, uint64_t d, float gamma_scale,
float gamma_exponent, float c, bool apply_shrink_factor, size_t thread_pool_size)
float gamma_exponent, float c, bool apply_shrink_factor, size_t thread_pool_size, size_t block_size)
{
using explore_type = cb_explore_adf_base<cb_explore_adf_large_action_space<T, S>>;

@@ -297,7 +297,7 @@ VW::LEARNER::base_learner* make_las_with_impl(VW::setup_base_i& stack_builder, V
uint64_t seed = all.get_random_state()->get_current_state() * 10.f;

auto data = VW::make_unique<explore_type>(with_metrics, d, gamma_scale, gamma_exponent, c, apply_shrink_factor, &all,
seed, 1 << all.num_bits, thread_pool_size, impl_type);
seed, 1 << all.num_bits, thread_pool_size, block_size, impl_type);

auto* l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict,
stack_builder.get_setupfn_name(VW::reductions::cb_explore_adf_large_action_space_setup))
@@ -331,6 +331,7 @@ VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_large_action_space_set
bool use_vanilla_impl = false;
bool full_spanner = false;
size_t thread_pool_size = 0;
size_t block_size = 0;

config::option_group_definition new_options(
"[Reduction] Experimental: Contextual Bandit Exploration with ADF with large action space");
@@ -346,6 +347,10 @@ VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_large_action_space_set
.default_value(0)
.help("number of threads in the thread pool that will be used when running with one pass svd "
"implementation (default svd implementation option)"))
.add(make_option("block_size", block_size)
.default_value(0)
.help("number of actions in a block to be scheduled for multithreading when using one pass svd "
"implementation (by default, block_size = num_actions / thread_pool_size)"))
.add(make_option("large_action_space", large_action_space)
.necessary()
.keep()
@@ -400,12 +405,12 @@ VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_large_action_space_set
if (full_spanner)
{
return make_las_with_impl<model_weight_rand_svd_impl, spanner_state>(stack_builder, base, impl_type, all,
with_metrics, d, gamma_scale, gamma_exponent, c, apply_shrink_factor, thread_pool_size);
with_metrics, d, gamma_scale, gamma_exponent, c, apply_shrink_factor, thread_pool_size, block_size);
}
else
{
return make_las_with_impl<model_weight_rand_svd_impl, one_rank_spanner_state>(stack_builder, base, impl_type, all,
with_metrics, d, gamma_scale, gamma_exponent, c, apply_shrink_factor, thread_pool_size);
with_metrics, d, gamma_scale, gamma_exponent, c, apply_shrink_factor, thread_pool_size, block_size);
}
}
else if (use_vanilla_impl)
@@ -414,12 +419,12 @@ VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_large_action_space_set
if (full_spanner)
{
return make_las_with_impl<vanilla_rand_svd_impl, spanner_state>(stack_builder, base, impl_type, all, with_metrics,
d, gamma_scale, gamma_exponent, c, apply_shrink_factor, thread_pool_size);
d, gamma_scale, gamma_exponent, c, apply_shrink_factor, thread_pool_size, block_size);
}
else
{
return make_las_with_impl<vanilla_rand_svd_impl, one_rank_spanner_state>(stack_builder, base, impl_type, all,
with_metrics, d, gamma_scale, gamma_exponent, c, apply_shrink_factor, thread_pool_size);
with_metrics, d, gamma_scale, gamma_exponent, c, apply_shrink_factor, thread_pool_size, block_size);
}
}
else
@@ -428,12 +433,12 @@ VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_large_action_space_set
if (full_spanner)
{
return make_las_with_impl<one_pass_svd_impl, spanner_state>(stack_builder, base, impl_type, all, with_metrics, d,
gamma_scale, gamma_exponent, c, apply_shrink_factor, thread_pool_size);
gamma_scale, gamma_exponent, c, apply_shrink_factor, thread_pool_size, block_size);
}
else
{
return make_las_with_impl<one_pass_svd_impl, one_rank_spanner_state>(stack_builder, base, impl_type, all,
with_metrics, d, gamma_scale, gamma_exponent, c, apply_shrink_factor, thread_pool_size);
with_metrics, d, gamma_scale, gamma_exponent, c, apply_shrink_factor, thread_pool_size, block_size);
}
}
}
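With the option registration and plumbing above in place, the new knob can be passed on the command line alongside the existing --thread_pool_size option; a possible invocation (the data file is a placeholder, and per the help text both options only take effect with the default one-pass SVD implementation):

vw -d data.txt --cb_explore_adf --large_action_space --max_actions 20 --thread_pool_size 8 --block_size 16 -q :: --quiet

Leaving --block_size at its default of 0 keeps the evenly-split behavior, block_size = num_actions / thread_pool_size.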
Original file line number Diff line number Diff line change
@@ -337,7 +337,7 @@ void model_weight_rand_svd_impl::run(const multi_ex& examples, const std::vector
}

model_weight_rand_svd_impl::model_weight_rand_svd_impl(
VW::workspace* all, uint64_t d, uint64_t seed, size_t total_size, size_t)
VW::workspace* all, uint64_t d, uint64_t seed, size_t total_size, size_t, size_t)
: _all(all), _d(d), _seed(seed), _internal_weights(total_size)
{
}
Original file line number Diff line number Diff line change
@@ -51,19 +51,13 @@ struct AO_triplet_constructor
{
private:
uint64_t _weights_mask;
uint64_t _row_index;
uint64_t _column_index;
uint64_t _seed;
float& _final_dot_product;

public:
AO_triplet_constructor(
uint64_t weights_mask, uint64_t row_index, uint64_t column_index, uint64_t seed, float& final_dot_product)
: _weights_mask(weights_mask)
, _row_index(row_index)
, _column_index(column_index)
, _seed(seed)
, _final_dot_product(final_dot_product)
AO_triplet_constructor(uint64_t weights_mask, uint64_t column_index, uint64_t seed, float& final_dot_product)
: _weights_mask(weights_mask), _column_index(column_index), _seed(seed), _final_dot_product(final_dot_product)
{
}

@@ -94,33 +88,47 @@ void one_pass_svd_impl::generate_AOmega(const multi_ex& examples, const std::vec
auto p = std::min(num_actions, _d + sampling_slack);
AOmega.resize(num_actions, p);

auto calculate_aomega_row = [](uint64_t row_index, uint64_t p, VW::workspace* _all, uint64_t _seed, VW::example* ex,
Eigen::MatrixXf& AOmega, const std::vector<float>& shrink_factors) -> void {
auto& red_features = ex->_reduction_features.template get<VW::generated_interactions::reduction_features>();

for (uint64_t col = 0; col < p; ++col)
auto calculate_aomega_row = [](uint64_t row_index_begin, uint64_t row_index_end, uint64_t p, VW::workspace* _all,
uint64_t _seed, const multi_ex& examples, Eigen::MatrixXf& AOmega,
const std::vector<float>& shrink_factors) -> void {
for (auto row_index = row_index_begin; row_index < row_index_end; ++row_index)
{
float final_dot_prod = 0.f;
VW::example* ex = examples[row_index];
auto& red_features = ex->_reduction_features.template get<VW::generated_interactions::reduction_features>();

for (uint64_t col = 0; col < p; ++col)
{
float final_dot_prod = 0.f;

AO_triplet_constructor tc(_all->weights.mask(), row_index, col, _seed, final_dot_prod);
AO_triplet_constructor tc(_all->weights.mask(), col, _seed, final_dot_prod);

GD::foreach_feature<AO_triplet_constructor, uint64_t, triplet_construction, dense_parameters>(
_all->weights.dense_weights, _all->ignore_some_linear, _all->ignore_linear,
(red_features.generated_interactions ? *red_features.generated_interactions : *ex->interactions),
(red_features.generated_extent_interactions ? *red_features.generated_extent_interactions
: *ex->extent_interactions),
_all->permutations, *ex, tc, _all->_generate_interactions_object_cache);
GD::foreach_feature<AO_triplet_constructor, uint64_t, triplet_construction, dense_parameters>(
_all->weights.dense_weights, _all->ignore_some_linear, _all->ignore_linear,
(red_features.generated_interactions ? *red_features.generated_interactions : *ex->interactions),
(red_features.generated_extent_interactions ? *red_features.generated_extent_interactions
: *ex->extent_interactions),
_all->permutations, *ex, tc, _all->_generate_interactions_object_cache);

AOmega(row_index, col) = final_dot_prod * shrink_factors[row_index];
AOmega(row_index, col) = final_dot_prod * shrink_factors[row_index];
}
}
};

uint64_t row_index = 0;
for (auto* ex : examples)
if (_block_size == 0)
{
// Compute block_size if not specified.
const size_t num_blocks = std::max(1UL, this->_thread_pool.size());
_block_size = examples.size() / num_blocks; // Evenly split the examples into blocks
}
for (size_t row_index_begin = 0; row_index_begin < examples.size();)
{
_futures.emplace_back(_thread_pool.submit(
calculate_aomega_row, row_index, p, _all, _seed, ex, std::ref(AOmega), std ::ref(shrink_factors)));
row_index++;
size_t row_index_end = row_index_begin + _block_size;
if ((row_index_end + _block_size) > examples.size()) { row_index_end = examples.size(); }

_futures.emplace_back(_thread_pool.submit(calculate_aomega_row, row_index_begin, row_index_end, p, _all, _seed,
std::cref(examples), std::ref(AOmega), std::cref(shrink_factors)));

row_index_begin = row_index_end;
}

for (auto& ft : _futures) { ft.get(); }
@@ -142,8 +150,9 @@ void one_pass_svd_impl::run(const multi_ex& examples, const std::vector<float>&
}
}

one_pass_svd_impl::one_pass_svd_impl(VW::workspace* all, uint64_t d, uint64_t seed, size_t, size_t thread_pool_size)
: _all(all), _d(d), _seed(seed), _thread_pool(thread_pool_size)
one_pass_svd_impl::one_pass_svd_impl(
VW::workspace* all, uint64_t d, uint64_t seed, size_t, size_t thread_pool_size, size_t block_size)
: _all(all), _d(d), _seed(seed), _thread_pool(thread_pool_size), _block_size(block_size)
{
}

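As a concrete check of the partitioning loop in generate_AOmega above: with 50 example rows, a thread pool of size 8, and no explicit block size, block_size becomes 50 / 8 = 6 and the loop schedules eight blocks, seven of 6 rows each ([0,6), [6,12), ..., [36,42)) and a final block of 8 rows ([42,50)), because a tail that would be shorter than a full block is absorbed into the last block instead of being submitted separately.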
Original file line number Diff line number Diff line change
@@ -105,7 +105,7 @@ void one_rank_spanner_state::compute_spanner(
_X_inv.setIdentity(_d, _d);
_log_determinant_factor = 0;

float max_volume;
float max_volume{};
// Compute a basis contained in U.
for (uint64_t i = 0; i < _d; ++i)
{
@@ -127,7 +127,6 @@ void one_rank_spanner_state::compute_spanner(
// If replacing some row in _X results in larger volume, replace it with the row from U.
for (uint64_t i = 0; i < _d; ++i)
{
float max_volume;
uint64_t U_rid;
Eigen::VectorXf phi = _X_inv.row(i);
find_max_volume(U, phi, max_volume, U_rid);
Original file line number Diff line number Diff line change
@@ -199,7 +199,7 @@ void vanilla_rand_svd_impl::run(const multi_ex& examples, const std::vector<floa
}
}

vanilla_rand_svd_impl::vanilla_rand_svd_impl(VW::workspace* all, uint64_t d, uint64_t seed, size_t, size_t)
vanilla_rand_svd_impl::vanilla_rand_svd_impl(VW::workspace* all, uint64_t d, uint64_t seed, size_t, size_t, size_t)
: _all(all), _d(d), _seed(seed)
{
}
12 changes: 8 additions & 4 deletions vowpalwabbit/core/src/reductions/cb/details/large_action_space.h
@@ -41,7 +41,8 @@ class vanilla_rand_svd_impl
Eigen::SparseMatrix<float> Y;
Eigen::MatrixXf Z;

vanilla_rand_svd_impl(VW::workspace* all, uint64_t d, uint64_t seed, size_t total_size, size_t thread_pool_size);
vanilla_rand_svd_impl(
VW::workspace* all, uint64_t d, uint64_t seed, size_t total_size, size_t thread_pool_size, size_t block_size);
void run(const multi_ex& examples, const std::vector<float>& shrink_factors, Eigen::MatrixXf& U, Eigen::VectorXf& _S,
Eigen::MatrixXf& _V);
bool generate_Y(const multi_ex& examples, const std::vector<float>& shrink_factors);
@@ -65,7 +66,8 @@ class model_weight_rand_svd_impl
Eigen::SparseMatrix<float> Y;
Eigen::MatrixXf Z;

model_weight_rand_svd_impl(VW::workspace* all, uint64_t d, uint64_t seed, size_t total_size, size_t thread_pool_size);
model_weight_rand_svd_impl(
VW::workspace* all, uint64_t d, uint64_t seed, size_t total_size, size_t thread_pool_size, size_t block_size);

void run(const multi_ex& examples, const std::vector<float>& shrink_factors, Eigen::MatrixXf& U, Eigen::VectorXf& _S,
Eigen::MatrixXf& _V);
@@ -87,12 +89,14 @@ class one_pass_svd_impl
uint64_t _d;
uint64_t _seed;
thread_pool _thread_pool;
size_t _block_size;
std::vector<std::future<void>> _futures;
Eigen::JacobiSVD<Eigen::MatrixXf> _svd;

public:
Eigen::MatrixXf AOmega;
one_pass_svd_impl(VW::workspace* all, uint64_t d, uint64_t seed, size_t total_size, size_t thread_pool_size);
one_pass_svd_impl(
VW::workspace* all, uint64_t d, uint64_t seed, size_t total_size, size_t thread_pool_size, size_t block_size);
void run(const multi_ex& examples, const std::vector<float>& shrink_factors, Eigen::MatrixXf& U, Eigen::VectorXf& _S,
Eigen::MatrixXf& _V);
void generate_AOmega(const multi_ex& examples, const std::vector<float>& shrink_factors);
@@ -184,7 +188,7 @@ class cb_explore_adf_large_action_space

cb_explore_adf_large_action_space(uint64_t d, float gamma_scale, float gamma_exponent, float c,
bool apply_shrink_factor, VW::workspace* all, uint64_t seed, size_t total_size, size_t thread_pool_size,
implementation_type impl_type);
size_t block_size, implementation_type impl_type);

~cb_explore_adf_large_action_space() = default;

