Skip to content

Commit

Permalink
Remove the use of dmlc::Stream and dmlc serializer
Browse files Browse the repository at this point in the history
  • Loading branch information
hcho3 committed Jul 9, 2021
1 parent 5d7d968 commit ab621a5
Show file tree
Hide file tree
Showing 19 changed files with 342 additions and 237 deletions.
18 changes: 10 additions & 8 deletions include/treelite/annotator.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*!
* Copyright (c) 2017-2020 by Contributors
* Copyright (c) 2017-2021 by Contributors
* \file annotator.h
* \author Hyunsu Cho
* \brief Branch annotation tools
Expand All @@ -10,6 +10,8 @@
#include <treelite/tree.h>
#include <treelite/data.h>
#include <vector>
#include <cstdio>
#include <cstdint>

namespace treelite {

Expand All @@ -27,14 +29,14 @@ class BranchAnnotator {
void Annotate(const Model& model, const DMatrix* dmat, int nthread, int verbose);
/*!
* \brief load branch annotation from a JSON file
* \param fi input stream
* \param fp input stream
*/
void Load(dmlc::Stream* fi);
void Load(FILE* fp);
/*!
* \brief save branch annotation to a JSON file
* \param fo output stream
* \param fp output stream
*/
void Save(dmlc::Stream* fo) const;
void Save(FILE* fp) const;
/*!
* \brief fetch branch annotation.
* Usage example:
Expand All @@ -48,12 +50,12 @@ class BranchAnnotator {
* \endcode
* \return branch annotation in 2D vector
*/
inline std::vector<std::vector<size_t>> Get() const {
return counts;
inline std::vector<std::vector<uint64_t>> Get() const {
return counts_;
}

private:
std::vector<std::vector<size_t>> counts;
std::vector<std::vector<uint64_t>> counts_;
};

} // namespace treelite
Expand Down
24 changes: 9 additions & 15 deletions include/treelite/tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <algorithm>
#include <map>
#include <memory>
#include <ostream>
#include <string>
#include <vector>
#include <utility>
Expand All @@ -27,14 +28,6 @@

#define TREELITE_MAX_PRED_TRANSFORM_LENGTH 256

/* Foward declarations */
namespace dmlc {

class Stream;
float stof(const std::string& value, std::size_t* pos);

} // namespace dmlc

namespace treelite {

// Represent a frame in the Python buffer protocol (PEP 3118). We use a simplified representation
Expand Down Expand Up @@ -161,7 +154,7 @@ enum class TaskType : uint8_t {
};

/*! \brief Group of parameters that are dependent on the choice of the task type. */
struct TaskParameter {
struct TaskParam {
enum class OutputType : uint8_t { kFloat = 0, kInt = 1 };
/*! \brief The type of output from each leaf node. */
OutputType output_type;
Expand Down Expand Up @@ -190,7 +183,7 @@ struct TaskParameter {
unsigned int leaf_vector_size;
};

static_assert(std::is_pod<TaskParameter>::value, "TaskParameter must be POD type");
static_assert(std::is_pod<TaskParam>::value, "TaskParameter must be POD type");

/*! \brief in-memory representation of a decision tree */
template <typename ThresholdType, typename LeafOutputType>
Expand Down Expand Up @@ -289,6 +282,9 @@ class Tree {
ContiguousArray<uint32_t> matching_categories_;
ContiguousArray<std::size_t> matching_categories_offset_;

template <typename WriterType, typename X, typename Y>
friend void SerializeTreeToJSON(WriterType& writer, const Tree<X, Y>& tree);

// allocate a new node
inline int AllocNode();

Expand Down Expand Up @@ -562,8 +558,6 @@ class Tree {
node.gain_ = gain;
node.gain_present_ = true;
}

void ReferenceSerialize(dmlc::Stream* fo) const;
};

struct ModelParam {
Expand Down Expand Up @@ -656,7 +650,7 @@ class Model {

virtual std::size_t GetNumTree() const = 0;
virtual void SetTreeLimit(std::size_t limit) = 0;
virtual void ReferenceSerialize(dmlc::Stream* fo) const = 0;
virtual void SerializeToJSON(std::ostream& fo) const = 0;

/* In-memory serialization, zero-copy */
std::vector<PyBufferFrame> GetPyBuffer();
Expand All @@ -676,7 +670,7 @@ class Model {
/*! \brief whether to average tree outputs */
bool average_tree_output;
/*! \brief Group of parameters that are specific to the particular task type */
TaskParameter task_param;
TaskParam task_param;
/*! \brief extra parameters */
ModelParam param;

Expand Down Expand Up @@ -712,7 +706,7 @@ class ModelImpl : public Model {
ModelImpl(ModelImpl&&) noexcept = default;
ModelImpl& operator=(ModelImpl&&) noexcept = default;

void ReferenceSerialize(dmlc::Stream* fo) const override;
void SerializeToJSON(std::ostream& fo) const override;
inline std::size_t GetNumTree() const override {
return trees.size();
}
Expand Down
4 changes: 2 additions & 2 deletions include/treelite/tree_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -296,9 +296,9 @@ ModelParam::InitAllowUnknown(const Container& kwargs) {
TREELITE_MAX_PRED_TRANSFORM_LENGTH - 1);
this->pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH - 1] = '\0';
} else if (e.first == "sigmoid_alpha") {
this->sigmoid_alpha = dmlc::stof(e.second, nullptr);
this->sigmoid_alpha = std::stof(e.second, nullptr);
} else if (e.first == "global_bias") {
this->global_bias = dmlc::stof(e.second, nullptr);
this->global_bias = std::stof(e.second, nullptr);
}
}
return unknowns;
Expand Down
2 changes: 1 addition & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ target_sources(objtreelite
filesystem.cc
optable.cc
serializer.cc
reference_serializer.cc
json_serializer.cc
${PROJECT_SOURCE_DIR}/include/treelite/annotator.h
${PROJECT_SOURCE_DIR}/include/treelite/base.h
${PROJECT_SOURCE_DIR}/include/treelite/c_api.h
Expand Down
70 changes: 49 additions & 21 deletions src/annotator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
#include <treelite/annotator.h>
#include <treelite/math.h>
#include <treelite/omp.h>
#include <dmlc/json.h>
#include <rapidjson/filereadstream.h>
#include <rapidjson/filewritestream.h>
#include <rapidjson/writer.h>
#include <rapidjson/document.h>
#include <limits>
#include <cstdint>

Expand All @@ -22,7 +25,7 @@ union Entry {

template <typename ElementType, typename ThresholdType, typename LeafOutputType>
void Traverse_(const treelite::Tree<ThresholdType, LeafOutputType>& tree,
const Entry<ElementType>* data, int nid, size_t* out_counts) {
const Entry<ElementType>* data, int nid, uint64_t* out_counts) {
++out_counts[nid];
if (!tree.IsLeaf(nid)) {
const unsigned split_index = tree.SplitIndex(nid);
Expand Down Expand Up @@ -57,15 +60,15 @@ void Traverse_(const treelite::Tree<ThresholdType, LeafOutputType>& tree,

template <typename ElementType, typename ThresholdType, typename LeafOutputType>
void Traverse(const treelite::Tree<ThresholdType, LeafOutputType>& tree,
const Entry<ElementType>* data, size_t* out_counts) {
const Entry<ElementType>* data, uint64_t* out_counts) {
Traverse_(tree, data, 0, out_counts);
}

template <typename ElementType, typename ThresholdType, typename LeafOutputType>
inline void ComputeBranchLoopImpl(
const treelite::ModelImpl<ThresholdType, LeafOutputType>& model,
const treelite::DenseDMatrixImpl<ElementType>* dmat, size_t rbegin, size_t rend, int nthread,
const size_t* count_row_ptr, size_t* counts_tloc) {
const size_t* count_row_ptr, uint64_t* counts_tloc) {
std::vector<Entry<ElementType>> inst(nthread * dmat->num_col, {-1});
const size_t ntree = model.trees.size();
CHECK_LE(rbegin, rend);
Expand Down Expand Up @@ -102,7 +105,7 @@ template <typename ElementType, typename ThresholdType, typename LeafOutputType>
inline void ComputeBranchLoopImpl(
const treelite::ModelImpl<ThresholdType, LeafOutputType>& model,
const treelite::CSRDMatrixImpl<ElementType>* dmat, size_t rbegin, size_t rend, int nthread,
const size_t* count_row_ptr, size_t* counts_tloc) {
const size_t* count_row_ptr, uint64_t* counts_tloc) {
std::vector<Entry<ElementType>> inst(nthread * dmat->num_col, {-1});
const size_t ntree = model.trees.size();
CHECK_LE(rbegin, rend);
Expand Down Expand Up @@ -135,7 +138,7 @@ class ComputeBranchLoopDispatcherWithDenseDMatrix {
inline static void Dispatch(
const treelite::ModelImpl<ThresholdType, LeafOutputType>& model,
const treelite::DMatrix* dmat, size_t rbegin, size_t rend, int nthread,
const size_t* count_row_ptr, size_t* counts_tloc) {
const size_t* count_row_ptr, uint64_t* counts_tloc) {
const auto* dmat_ = static_cast<const treelite::DenseDMatrixImpl<ElementType>*>(dmat);
CHECK(dmat_) << "Dangling data matrix reference detected";
ComputeBranchLoopImpl(model, dmat_, rbegin, rend, nthread, count_row_ptr, counts_tloc);
Expand All @@ -149,7 +152,7 @@ class ComputeBranchLoopDispatcherWithCSRDMatrix {
inline static void Dispatch(
const treelite::ModelImpl<ThresholdType, LeafOutputType>& model,
const treelite::DMatrix* dmat, size_t rbegin, size_t rend, int nthread,
const size_t* count_row_ptr, size_t* counts_tloc) {
const size_t* count_row_ptr, uint64_t* counts_tloc) {
const auto* dmat_ = static_cast<const treelite::CSRDMatrixImpl<ElementType>*>(dmat);
CHECK(dmat_) << "Dangling data matrix reference detected";
ComputeBranchLoopImpl(model, dmat_, rbegin, rend, nthread, count_row_ptr, counts_tloc);
Expand All @@ -160,7 +163,7 @@ template <typename ThresholdType, typename LeafOutputType>
inline void ComputeBranchLoop(const treelite::ModelImpl<ThresholdType, LeafOutputType>& model,
const treelite::DMatrix* dmat, size_t rbegin,
size_t rend, int nthread, const size_t* count_row_ptr,
size_t* counts_tloc) {
uint64_t* counts_tloc) {
switch (dmat->GetType()) {
case treelite::DMatrixType::kDense: {
treelite::DispatchWithTypeInfo<ComputeBranchLoopDispatcherWithDenseDMatrix>(
Expand Down Expand Up @@ -188,9 +191,9 @@ inline void
AnnotateImpl(
const treelite::ModelImpl<ThresholdType, LeafOutputType>& model,
const treelite::DMatrix* dmat, int nthread, int verbose,
std::vector<std::vector<size_t>>* out_counts) {
std::vector<size_t> new_counts;
std::vector<size_t> counts_tloc;
std::vector<std::vector<uint64_t>>* out_counts) {
std::vector<uint64_t> new_counts;
std::vector<uint64_t> counts_tloc;
std::vector<size_t> count_row_ptr;

count_row_ptr = {0};
Expand Down Expand Up @@ -223,7 +226,7 @@ AnnotateImpl(
}

// change layout of counts
std::vector<std::vector<size_t>>& counts = *out_counts;
std::vector<std::vector<uint64_t>>& counts = *out_counts;
for (size_t i = 0; i < ntree; ++i) {
counts.emplace_back(&new_counts[count_row_ptr[i]], &new_counts[count_row_ptr[i + 1]]);
}
Expand All @@ -233,22 +236,47 @@ void
BranchAnnotator::Annotate(const Model& model, const DMatrix* dmat, int nthread, int verbose) {
TypeInfo threshold_type = model.GetThresholdType();
model.Dispatch([this, dmat, nthread, verbose, threshold_type](auto& handle) {
AnnotateImpl(handle, dmat, nthread, verbose, &this->counts);
AnnotateImpl(handle, dmat, nthread, verbose, &this->counts_);
});
}

void
BranchAnnotator::Load(dmlc::Stream* fi) {
dmlc::istream is(fi);
std::unique_ptr<dmlc::JSONReader> reader(new dmlc::JSONReader(&is));
reader->Read(&counts);
BranchAnnotator::Load(FILE* fp) {
CHECK(fp) << "Invalid file stream";
char read_buffer[65536];
rapidjson::FileReadStream is(fp, read_buffer, sizeof(read_buffer));

rapidjson::Document doc;
doc.ParseStream(is);

std::string err_msg = "JSON file must contain a list of lists of integers";
CHECK(doc.IsArray()) << err_msg;
counts_.clear();
for (const auto& node_cnt : doc.GetArray()) {
CHECK(node_cnt.IsArray()) << err_msg;
counts_.emplace_back();
for (const auto& e : node_cnt.GetArray()) {
counts_.back().push_back(e.GetUint64());
}
}
}

void
BranchAnnotator::Save(dmlc::Stream* fo) const {
dmlc::ostream os(fo);
std::unique_ptr<dmlc::JSONWriter> writer(new dmlc::JSONWriter(&os));
writer->Write(counts);
BranchAnnotator::Save(FILE* fp) const {
CHECK(fp) << "Invalid file stream";
char write_buffer[65536];
rapidjson::FileWriteStream os(fp, write_buffer, sizeof(write_buffer));
rapidjson::Writer<rapidjson::FileWriteStream> writer(os);

writer.StartArray();
for (const auto& node_cnt : counts_) {
writer.StartArray();
for (auto e : node_cnt) {
writer.Uint64(e);
}
writer.EndArray();
}
writer.EndArray();
}

} // namespace treelite
6 changes: 4 additions & 2 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <dmlc/thread_local.h>
#include <memory>
#include <algorithm>
#include <cstdio>

using namespace treelite;

Expand Down Expand Up @@ -52,8 +53,9 @@ int TreeliteAnnotationSave(AnnotationHandle handle,
const char* path) {
API_BEGIN();
const BranchAnnotator* annotator = static_cast<BranchAnnotator*>(handle);
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(path, "w"));
annotator->Save(fo.get());
FILE* fo = std::fopen(path, "w");
annotator->Save(fo);
std::fclose(fo);
API_END();
}

Expand Down
5 changes: 3 additions & 2 deletions src/compiler/ast/ast.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*!
* Copyright (c) 2017-2020 by Contributors
* Copyright (c) 2017-2021 by Contributors
* \file ast.h
* \brief Definition for AST classes
* \author Hyunsu Cho
Expand All @@ -14,6 +14,7 @@
#include <string>
#include <vector>
#include <utility>
#include <cstdint>

namespace treelite {
namespace compiler {
Expand All @@ -24,7 +25,7 @@ class ASTNode {
std::vector<ASTNode*> children;
int node_id;
int tree_id;
dmlc::optional<size_t> data_count;
dmlc::optional<uint64_t> data_count;
dmlc::optional<double> sum_hess;
virtual std::string GetDump() const = 0;
virtual ~ASTNode() = 0; // force ASTNode to be abstract class
Expand Down
3 changes: 2 additions & 1 deletion src/compiler/ast/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <ostream>
#include <utility>
#include <memory>
#include <cstdint>
#include "./ast.h"

namespace treelite {
Expand Down Expand Up @@ -58,7 +59,7 @@ class ASTBuilder {
/* \brief replace split thresholds with integers */
void QuantizeThresholds();
/* \brief Load data counts from annotation file */
void LoadDataCounts(const std::vector<std::vector<size_t>>& counts);
void LoadDataCounts(const std::vector<std::vector<uint64_t>>& counts);
/*
* \brief Get a text representation of AST
*/
Expand Down
Loading

0 comments on commit ab621a5

Please sign in to comment.