Skip to content

Commit

Permalink
refactor: migrate cs_active finish function (#4394)
Browse files Browse the repository at this point in the history
* refactor: migrate cs_active finish function

* format

* fix copy
  • Loading branch information
jackgerrits authored Dec 28, 2022
1 parent 7c1d200 commit 863f04f
Showing 1 changed file with 73 additions and 4 deletions.
77 changes: 73 additions & 4 deletions vowpalwabbit/core/src/reductions/cs_active.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
#include "vw/common/vw_exception.h"
#include "vw/config/options.h"
#include "vw/core/debug_log.h"
#include "vw/core/example.h"
#include "vw/core/learner.h"
#include "vw/core/loss_functions.h"
#include "vw/core/named_labels.h"
#include "vw/core/rand48.h"
#include "vw/core/reductions/csoaa.h"
#include "vw/core/setup_base.h"
Expand Down Expand Up @@ -302,12 +304,77 @@ void predict_or_learn(cs_active& cs_a, single_learner& base, VW::example& ec)
ec.l.cs = ld;
}

void finish_example(VW::workspace& all, cs_active&, VW::example& ec)
void update_stats_cs_active(const VW::workspace& /* all */, VW::shared_data& sd, const cs_active& /* data */,
const VW::example& ec, VW::io::logger& logger)
{
VW::details::output_cs_example(all, ec, ec.l.cs, ec.pred.active_multiclass.predicted_class);
VW::finish_example(all, ec);
const auto& label = ec.l.cs;
const auto multiclass_prediction = ec.pred.active_multiclass.predicted_class;
float loss = 0.;
if (!label.is_test_label())
{
// need to compute exact loss
auto pred = static_cast<size_t>(multiclass_prediction);

float chosen_loss = FLT_MAX;
float min = FLT_MAX;
for (const auto& cl : label.costs)
{
if (cl.class_index == pred) { chosen_loss = cl.x; }
if (cl.x < min) { min = cl.x; }
}
if (chosen_loss == FLT_MAX)
{
logger.err_warn("csoaa predicted an invalid class. Are all multi-class labels in the {{1..k}} range?");
}

loss = (chosen_loss - min) * ec.weight;
// TODO(alberto): add option somewhere to allow using absolute loss instead?
// loss = chosen_loss;
}

sd.update(ec.test_only, !label.is_test_label(), loss, ec.weight, ec.get_num_features());
}

void output_example_prediction_cs_active(
VW::workspace& all, const cs_active& /* data */, const VW::example& ec, VW::io::logger& /* unused */)
{
const auto& label = ec.l.cs;
const auto multiclass_prediction = ec.pred.active_multiclass.predicted_class;

for (auto& sink : all.final_prediction_sink)
{
if (!all.sd->ldict)
{
all.print_by_ref(sink.get(), static_cast<float>(multiclass_prediction), 0, ec.tag, all.logger);
}
else
{
VW::string_view sv_pred = all.sd->ldict->get(multiclass_prediction);
all.print_text_by_ref(sink.get(), std::string{sv_pred}, ec.tag, all.logger);
}
}

if (all.raw_prediction != nullptr)
{
std::stringstream output_string_stream;
for (unsigned int i = 0; i < label.costs.size(); i++)
{
const auto& cl = label.costs[i];
if (i > 0) { output_string_stream << ' '; }
output_string_stream << cl.class_index << ':' << cl.partial_prediction;
}
all.print_text_by_ref(all.raw_prediction.get(), output_string_stream.str(), ec.tag, all.logger);
}
}

void print_update_cs_active(VW::workspace& all, VW::shared_data& /* sd */, const cs_active& /* data */,
const VW::example& ec, VW::io::logger& /* unused */)
{
const auto& label = ec.l.cs;
const auto multiclass_prediction = ec.pred.active_multiclass.predicted_class;

VW::details::print_cs_update(all, label.is_test_label(), ec, nullptr, false, multiclass_prediction);
}
} // namespace

base_learner* VW::reductions::cs_active_setup(VW::setup_base_i& stack_builder)
Expand Down Expand Up @@ -396,7 +463,9 @@ base_learner* VW::reductions::cs_active_setup(VW::setup_base_i& stack_builder)
.set_learn_returns_prediction(true)
.set_output_prediction_type(VW::prediction_type_t::ACTIVE_MULTICLASS)
.set_input_label_type(VW::label_type_t::CS)
.set_finish_example(::finish_example)
.set_output_example_prediction(output_example_prediction_cs_active)
.set_print_update(print_update_cs_active)
.set_update_stats(update_stats_cs_active)
.build();

// Label parser set to cost sensitive label parser
Expand Down

0 comments on commit 863f04f

Please sign in to comment.