Skip to content

Commit

Permalink
float64 support in treelite->FIL import and Python layer (#4690)
Browse files Browse the repository at this point in the history
`float64` support in treelite->FIL import and Python layer

Authors:
  - Andy Adinets (https://github.com/canonizer)
  - Levs Dolgovs (https://github.com/levsnv)

Approvers:
  - Philip Hyunsu Cho (https://github.com/hcho3)
  - William Hicks (https://github.com/wphicks)

URL: #4690
  • Loading branch information
canonizer authored Apr 13, 2022
1 parent 5689b96 commit 57124ce
Show file tree
Hide file tree
Showing 6 changed files with 223 additions and 129 deletions.
4 changes: 3 additions & 1 deletion cpp/bench/sg/fil.cu
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ class FIL : public RegressionFixture<float> {
.threads_per_tree = 1,
.n_items = 0,
.pforest_shape_str = nullptr};
ML::fil::from_treelite(*handle, &forest, model, &tl_params);
ML::fil::forest_variant forest_variant;
ML::fil::from_treelite(*handle, &forest_variant, model, &tl_params);
forest = std::get<ML::fil::forest_t<float>>(forest_variant);

// only time prediction
this->loopOnState(state, [this]() {
Expand Down
18 changes: 12 additions & 6 deletions cpp/include/cuml/fil/fil.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

#include <stddef.h>

#include <variant> // for std::get<>, std::variant<>

#include <cuml/ensemble/treelite_defs.hpp>

namespace raft {
Expand All @@ -29,10 +31,8 @@ class handle_t;
namespace ML {
namespace fil {

/** @note FIL only supports inference with single precision.
* TODO(canonizer): parameterize the functions and structures by the data type
* and the threshold/weight type.
*/
/** @note FIL supports inference with both single and double precision. However,
the floating-point type used in the data and model must be the same. */

/** Inference algorithm to use. */
enum algo_t {
Expand Down Expand Up @@ -76,6 +76,13 @@ struct forest;
template <typename real_t>
using forest_t = forest<real_t>*;

/** forest32_t and forest64_t are definitions required in Cython */
using forest32_t = forest<float>*;
using forest64_t = forest<double>*;

/** forest_variant is used to get a forest represented with either float or double. */
using forest_variant = std::variant<forest_t<float>, forest_t<double>>;

/** MAX_N_ITEMS determines the maximum allowed value for tl_params::n_items */
constexpr int MAX_N_ITEMS = 4;

Expand Down Expand Up @@ -114,9 +121,8 @@ struct treelite_params_t {
* @param model treelite model used to initialize the forest
* @param tl_params additional parameters for the forest
*/
// TODO (canonizer): use std::variant<forest_t<float> forest_t<double>>* for pforest
void from_treelite(const raft::handle_t& handle,
forest_t<float>* pforest,
forest_variant* pforest,
ModelHandle model,
const treelite_params_t* tl_params);

Expand Down
69 changes: 39 additions & 30 deletions cpp/src/fil/treelite_import.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#include <cstddef> // for std::size_t
#include <cstdint> // for uint8_t
#include <iosfwd> // for ios, stringstream
#include <limits> // for std::numeric_limits
#include <stack> // for std::stack
#include <string> // for std::string
#include <type_traits> // for std::is_same
Expand Down Expand Up @@ -223,7 +224,8 @@ cat_sets_owner allocate_cat_sets_owner(const tl::ModelImpl<T, L>& model)
return cat_sets;
}

void adjust_threshold(float* pthreshold, bool* swap_child_nodes, tl::Operator comparison_op)
template <typename real_t>
void adjust_threshold(real_t* pthreshold, bool* swap_child_nodes, tl::Operator comparison_op)
{
// in treelite (take left node if val [op] threshold),
// the meaning of the condition is reversed compared to FIL;
Expand All @@ -237,12 +239,12 @@ void adjust_threshold(float* pthreshold, bool* swap_child_nodes, tl::Operator co
case tl::Operator::kLT: break;
case tl::Operator::kLE:
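// x <= y is equivalent to x < y', where y' is the next representable value of real_t after y.
// NOTE(review): std::nextafterf takes and returns float, so a double threshold would be
// narrowed to float here; the type-generic std::nextafter overload preserves real_t precision.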
*pthreshold = std::nextafterf(*pthreshold, std::numeric_limits<float>::infinity());
*pthreshold = std::nextafterf(*pthreshold, std::numeric_limits<real_t>::infinity());
break;
case tl::Operator::kGT:
// x > y is equivalent to x >= y', where y' is the next representable value of real_t after y
// left and right still need to be swapped
*pthreshold = std::nextafterf(*pthreshold, std::numeric_limits<float>::infinity());
*pthreshold = std::nextafterf(*pthreshold, std::numeric_limits<real_t>::infinity());
case tl::Operator::kGE:
// swap left and right
*swap_child_nodes = !*swap_child_nodes;
Expand Down Expand Up @@ -279,7 +281,7 @@ void tl2fil_leaf_payload(fil_node_t* fil_node,
const tl::Tree<T, L>& tl_tree,
int tl_node_id,
const forest_params_t& forest_params,
std::vector<float>* vector_leaf,
std::vector<typename fil_node_t::real_type>* vector_leaf,
size_t* leaf_counter)
{
auto vec = tl_tree.LeafVector(tl_node_id);
Expand All @@ -301,7 +303,7 @@ void tl2fil_leaf_payload(fil_node_t* fil_node,
}
case leaf_algo_t::FLOAT_UNARY_BINARY:
case leaf_algo_t::GROVE_PER_CLASS:
fil_node->val.f = static_cast<float>(tl_tree.LeafValue(tl_node_id));
fil_node->val.f = static_cast<typename fil_node_t::real_type>(tl_tree.LeafValue(tl_node_id));
ASSERT(!tl_tree.HasLeafVector(tl_node_id),
"some but not all treelite leaves have leaf_vector()");
break;
Expand All @@ -323,14 +325,15 @@ conversion_state<fil_node_t> tl2fil_inner_node(int fil_left_child,
cat_sets_owner* cat_sets,
std::size_t* bit_pool_offset)
{
using real_t = typename fil_node_t::real_type;
int tl_left = tree.LeftChild(tl_node_id), tl_right = tree.RightChild(tl_node_id);
val_t<float> split = {.f = NAN}; // yes there's a default initializer already
val_t<real_t> split = {.f = std::numeric_limits<real_t>::quiet_NaN()};
int feature_id = tree.SplitIndex(tl_node_id);
bool is_categorical = tree.SplitType(tl_node_id) == tl::SplitFeatureType::kCategorical &&
tree.MatchingCategories(tl_node_id).size() > 0;
bool swap_child_nodes = false;
if (tree.SplitType(tl_node_id) == tl::SplitFeatureType::kNumerical) {
split.f = static_cast<float>(tree.Threshold(tl_node_id));
split.f = static_cast<real_t>(tree.Threshold(tl_node_id));
adjust_threshold(&split.f, &swap_child_nodes, tree.ComparisonOp(tl_node_id));
} else if (tree.SplitType(tl_node_id) == tl::SplitFeatureType::kCategorical) {
// for FIL, the list of categories is always for the right child
Expand All @@ -346,14 +349,14 @@ conversion_state<fil_node_t> tl2fil_inner_node(int fil_left_child,
}
} else {
// always branch left in FIL. Already accounted for Treelite branching direction above.
split.f = NAN;
split.f = std::numeric_limits<real_t>::quiet_NaN();
}
} else {
ASSERT(false, "only numerical and categorical split nodes are supported");
}
bool default_left = tree.DefaultLeft(tl_node_id) ^ swap_child_nodes;
fil_node_t node(
val_t<float>{}, split, feature_id, default_left, false, is_categorical, fil_left_child);
val_t<real_t>{}, split, feature_id, default_left, false, is_categorical, fil_left_child);
return conversion_state<fil_node_t>{node, swap_child_nodes};
}

Expand All @@ -363,7 +366,7 @@ int tree2fil(std::vector<fil_node_t>& nodes,
const tl::Tree<T, L>& tree,
std::size_t tree_idx,
const forest_params_t& forest_params,
std::vector<float>* vector_leaf,
std::vector<typename fil_node_t::real_type>* vector_leaf,
std::size_t* leaf_counter,
cat_sets_owner* cat_sets)
{
Expand Down Expand Up @@ -443,10 +446,11 @@ std::stringstream depth_hist_and_max(const tl::ModelImpl<T, L>& model)
forest_shape << "Total: branches: " << total_branches << " leaves: " << total_leaves
<< " nodes: " << total_nodes << endl;
forest_shape << "Avg nodes per tree: " << setprecision(2)
<< total_nodes / (float)hist[0].n_branch_nodes << endl;
<< total_nodes / static_cast<double>(hist[0].n_branch_nodes) << endl;
forest_shape.copyfmt(default_state);
forest_shape << "Leaf depth: min: " << min_leaf_depth << " avg: " << setprecision(2) << fixed
<< leaves_times_depth / (float)total_leaves << " max: " << hist.size() - 1 << endl;
<< leaves_times_depth / static_cast<double>(total_leaves)
<< " max: " << hist.size() - 1 << endl;
forest_shape.copyfmt(default_state);

vector<char> hist_bytes(hist.size() * sizeof(hist[0]));
Expand Down Expand Up @@ -575,9 +579,10 @@ void node_traits<node_t>::check(const treelite::ModelImpl<threshold_t, leaf_t>&

template <typename fil_node_t, typename threshold_t, typename leaf_t>
struct tl2fil_t {
using real_t = typename fil_node_t::real_type;
std::vector<int> roots_;
std::vector<fil_node_t> nodes_;
std::vector<float> vector_leaf_;
std::vector<real_t> vector_leaf_;
forest_params_t params_;
cat_sets_owner cat_sets_;
const tl::ModelImpl<threshold_t, leaf_t>& model_;
Expand Down Expand Up @@ -631,7 +636,7 @@ struct tl2fil_t {
}

/// initializes FIL forest object, to be ready to infer
void init_forest(const raft::handle_t& handle, forest_t<float>* pforest)
void init_forest(const raft::handle_t& handle, forest_t<real_t>* pforest)
{
ML::fil::init(
handle, pforest, cat_sets_.accessor(), vector_leaf_, roots_.data(), nodes_.data(), &params_);
Expand All @@ -646,7 +651,7 @@ struct tl2fil_t {

template <typename fil_node_t, typename threshold_t, typename leaf_t>
void convert(const raft::handle_t& handle,
forest_t<float>* pforest,
forest_t<typename fil_node_t::real_type>* pforest,
const tl::ModelImpl<threshold_t, leaf_t>& model,
const treelite_params_t& tl_params)
{
Expand All @@ -664,24 +669,21 @@ constexpr bool type_supported()

template <typename threshold_t, typename leaf_t>
void from_treelite(const raft::handle_t& handle,
forest_t<float>* pforest,
forest_variant* pforest_variant,
const tl::ModelImpl<threshold_t, leaf_t>& model,
const treelite_params_t* tl_params)
{
// floating-point type used for model representation
using real_t = decltype(threshold_t(0) + leaf_t(0));

// get the pointer to the right forest variant
*pforest_variant = (forest_t<real_t>)nullptr;
forest_t<real_t>* pforest = &std::get<forest_t<real_t>>(*pforest_variant);

// Invariants on threshold and leaf types
static_assert(type_supported<threshold_t>(),
"Model must contain float32 or float64 thresholds for splits");
ASSERT(type_supported<leaf_t>(), "Models with integer leaf output are not yet supported");
// Display appropriate warnings when float64 values are being casted into
// float32, as FIL only supports inferencing with float32 for the time being
if (std::is_same<threshold_t, double>::value || std::is_same<leaf_t, double>::value) {
CUML_LOG_WARN(
"Casting all thresholds and leaf values to float32, as FIL currently "
"doesn't support inferencing models with float64 values. "
"This may lead to predictions with reduced accuracy.");
}
// same as std::common_type: float+double=double, float+int64_t=float
using real_t = decltype(threshold_t(0) + leaf_t(0));

storage_type_t storage_type = tl_params->storage_type;
// build dense trees by default
Expand All @@ -702,18 +704,25 @@ void from_treelite(const raft::handle_t& handle,

switch (storage_type) {
case storage_type_t::DENSE:
convert<dense_node<float>>(handle, pforest, model, *tl_params);
convert<dense_node<real_t>>(handle, pforest, model, *tl_params);
break;
case storage_type_t::SPARSE:
convert<sparse_node16<float>>(handle, pforest, model, *tl_params);
convert<sparse_node16<real_t>>(handle, pforest, model, *tl_params);
break;
case storage_type_t::SPARSE8:
// SPARSE8 is only supported for float32
if constexpr (std::is_same_v<real_t, float>) {
convert<sparse_node8>(handle, pforest, model, *tl_params);
} else {
ASSERT(false, "SPARSE8 is only supported for float32 treelite models");
}
break;
case storage_type_t::SPARSE8: convert<sparse_node8>(handle, pforest, model, *tl_params); break;
default: ASSERT(false, "tl_params->sparse must be one of AUTO, DENSE or SPARSE");
}
}

void from_treelite(const raft::handle_t& handle,
forest_t<float>* pforest,
forest_variant* pforest,
ModelHandle model,
const treelite_params_t* tl_params)
{
Expand Down
Loading

0 comments on commit 57124ce

Please sign in to comment.