Add strategy module (PaddlePaddle#167)
* add primitive layer. add function primitive::add

* optimize the primitive::add function

* change the name and struct of add function

* change opfunction to pe.

* add op module. test ci

* add class graph and class node. early version

* add pass module

* add example pass and unittest

* add strategy module

* fix function name problem
haozech authored Aug 12, 2020
1 parent 8401b57 commit 83d1754
Showing 8 changed files with 199 additions and 14 deletions.
5 changes: 2 additions & 3 deletions cinn/backends/llvm/simple_jit.h
@@ -21,9 +21,8 @@

#include <functional>
#include <memory>
#include <mutex> // NOLINT
#include <optional>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include "cinn/backends/llvm/codegen_llvm.h"
5 changes: 2 additions & 3 deletions cinn/hlir/framework/op.h
@@ -12,7 +12,6 @@
#include <vector>

#include "cinn/common/macros.h"
#include "cinn/utils/base.h"
#include "cinn/utils/registry.h"

namespace cinn {
@@ -127,7 +126,7 @@ class Operator {
static const OpValueType<ValueType>& GetAttr(const std::string& attr_name) {
const std::any* ref = GetAttrMap(attr_name);
if (ref == nullptr) {
//! update the attribute map of the key by creating new empty OpMap
UpdateAttrMap(attr_name, [attr_name](std::any* pmap) {
if (!pmap->has_value()) {
OpValueType<ValueType> pm;
@@ -155,7 +154,7 @@ class Operator {
return nullptr;
}
}
//! update the attribute OpValueType
static void UpdateAttrMap(const std::string& key, std::function<void(std::any*)> updater) {
OpRegistry* reg = OpRegistry::Global();
std::lock_guard<std::recursive_mutex> lock(reg->mutex);
123 changes: 123 additions & 0 deletions cinn/hlir/framework/op_strategy.h
@@ -0,0 +1,123 @@
#pragma once
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "cinn/hlir/framework/schedule.h"
#include "cinn/ir/packed_func.h"

using CINNCompute = cinn::ir::PackedFunc;
using CINNSchedule = cinn::ir::PackedFunc;

namespace cinn {
namespace hlir {
namespace framework {

//! Operator implementation that includes the compute and schedule functions.
class OpImpl : public common::Object {
public:
//! Compute function
CINNCompute fcompute;
//! Schedule function
CINNSchedule fschedule;
//! Name of the implementation
std::string name;
//! Priority level
int plevel;
/**
* \brief Invoke the operator compute function.
* @param inputs The input tensors.
* @param out_type The output type information.
* @return The output compute description of the operator.
*/
ir::Tensor Compute(const std::vector<ir::Tensor>& inputs, const Type& out_type) {
// TODO(haozech) : add support for packedfunc to return Tensor
// Expected : return this->fcompute(inputs, out_type);
ir::Tensor temp;
return temp;
}
/**
* \brief Build the computation schedule.
* @param outs The output tensors.
* @param temp_tensors The intermediate tensors.
* @param target The build target.
* @return The computation schedule.
*/
common::Shared<Schedule> GetSchedule(const std::vector<ir::Tensor>& outs,
const std::vector<ir::Tensor>& temp_tensors,
const Target& target) {
// TODO(haozech) : add support for packedfunc to return Schedule
// Expected : return this->fschedule(outs, target);
return nullptr;
}

const char* type_info() const override { return _type_key; }

private:
static constexpr const char* _type_key = "OpImplementation";
};

//! Specialized implementations for operators under certain conditions.
class OpSpec : public common::Object {
public:
//! List of implementations.
std::vector<OpImpl*> implementations;

/** \brief Condition under which the specialization is enabled.
* May be left empty to represent the generic case.
* TODO(haozech) : build a dedicated class SpecializedCondition to represent the condition.
* Expected : SpecializedCondition condition;
*/
std::string condition;

const char* type_info() const override { return _type_key; }

void AddImpl(CINNCompute fcompute, CINNSchedule fschedule, std::string name, int plevel) {
auto n = make_shared<OpImpl>();
n->fcompute = fcompute;
n->fschedule = fschedule;
n->name = std::move(name);
n->plevel = plevel;
this->implementations.push_back(n);
}

private:
static constexpr const char* _type_key = "OpSpecialization";
};

//! Operator strategy class.
class OpStrategy : public common::Object {
public:
const char* type_info() const override { return "CINNOpStrategy"; }
//! List of operator specializations.
std::vector<OpSpec*> specializations;

/**
* \brief Add an implementation.
* @param fcompute Compute function
* @param fschedule Schedule function
* @param name Name of the implementation
* @param plevel Priority level of the implementation
*/
void AddImpl(CINNCompute fcompute, CINNSchedule fschedule, std::string name, int plevel) {
//! TODO(haozech) : here curr_cond should get the condition from outside.
//! Expected : auto curr_cond = SpecializedCondition::Current();
std::string curr_cond = "current_condition";
for (OpSpec* op_spec : specializations) {
if (op_spec->condition == curr_cond) {
op_spec->AddImpl(fcompute, fschedule, std::move(name), plevel);
return;
}
}
OpSpec* n = make_shared<OpSpec>();
n->condition = curr_cond;
n->AddImpl(fcompute, fschedule, std::move(name), plevel);
this->specializations.push_back(n);
}
};

} // namespace framework
} // namespace hlir
} // namespace cinn
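
The header above only stores implementations; it does not define how a caller picks one. Below is a minimal sketch of such a selection step — an assumption rather than code from this commit. SelectBestImpl is a hypothetical helper assumed to live in the cinn::hlir::framework namespace with op_strategy.h included, and a real version would also be expected to check each specialization's condition.

// Hypothetical helper (not part of this commit): pick the implementation with
// the highest priority level across all specializations of a strategy.
inline OpImpl* SelectBestImpl(const common::Shared<OpStrategy>& strategy) {
  OpImpl* best = nullptr;
  for (OpSpec* spec : strategy->specializations) {
    for (OpImpl* impl : spec->implementations) {
      if (best == nullptr || impl->plevel > best->plevel) best = impl;
    }
  }
  return best;
}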
29 changes: 27 additions & 2 deletions cinn/hlir/framework/op_test.cc
@@ -6,6 +6,10 @@

#include "cinn/cinn.h"

#include "cinn/hlir/framework/op_strategy.h"

#include "cinn/hlir/pe/broadcast.h"

namespace cinn {
namespace hlir {
namespace framework {
@@ -17,9 +21,30 @@ CINN_REGISTER_OP(add)
.set_attr<std::string>("nick_name", "plus")
.set_support_level(4);

common::Shared<OpStrategy> GetStrategyTest() {
ir::PackedFunc::body_t body = [](ir::Args args, ir::RetValue* ret) {
Expr a = args[0];
Expr b = args[1];
(*ret) = Expr(pe::Add(a.as_tensor_ref(), b.as_tensor_ref(), "C").get());
};
ir::PackedFunc fcompute(body);
// TODO(haozech): fschedule should be an instance of pe::schedule...
ir::PackedFunc fschedule;
common::Shared<OpStrategy> strategy(make_shared<OpStrategy>());
//! To build a more complex strategy, we can add more than one
//! implementation to an OpStrategy, each with a different plevel (see the sketch after this function).
strategy->AddImpl(fcompute, fschedule, "test.strategy", 10);
return strategy;
}
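
// Sketch (an assumption, not part of this commit): a second, lower-priority
// implementation can be attached to the same strategy; dispatch code is then
// expected to prefer the entry with the higher plevel. `fallback_compute` and
// `fallback_schedule` are hypothetical PackedFuncs supplied by the caller.
void AddFallbackImpl(const common::Shared<OpStrategy>& strategy,
                     ir::PackedFunc fallback_compute,
                     ir::PackedFunc fallback_schedule) {
  strategy->AddImpl(fallback_compute, fallback_schedule, "test.strategy.fallback", 5);
}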

TEST(Operator, GetAttr) {
auto add = Operator::Get("add");
auto test_strategy = GetStrategyTest();
Operator temp = *add;
temp.set_attr<common::Shared<OpStrategy>>("CINNStrategy", test_strategy);
auto nick = Operator::GetAttr<std::string>("nick_name");
auto strategy = Operator::GetAttr<common::Shared<OpStrategy>>("CINNStrategy");
ASSERT_EQ(strategy[add]->specializations[0]->implementations[0]->name, "test.strategy");
ASSERT_EQ(add->description, "test of op Add");
ASSERT_EQ(nick[add], "plus");
}
2 changes: 0 additions & 2 deletions cinn/hlir/framework/print_graph_pass_test.cc
@@ -60,9 +60,7 @@ TEST(Operator, GetAttr) {
ASSERT_EQ(s, "0:elementwise_add(Node_add0)\n1:elementwise_add(Node_add1)\n");
delete g;
delete output1;
delete node1;
delete output0;
delete node0;
}

} // namespace framework
42 changes: 42 additions & 0 deletions cinn/hlir/framework/schedule.h
@@ -0,0 +1,42 @@
#pragma once
#include <string>
#include <unordered_map>
#include <vector>
#include "cinn/cinn.h"
#include "cinn/lang/tensor.h"
namespace cinn {
namespace hlir {
namespace framework {
/**
* \brief Global schedule container.
* It covers the operations and all the operations they depend on.
* The schedule for each Operation is called a stage.
*/
class Schedule : public common::Object {
public:
const char* type_info() const override { return "CINNSchedule"; }

/**
* \brief Get the stage corresponding to the op.
* @param op The operation.
*/
ir::Tensor operator[](const ir::Operation& op) {
auto it = stage_map.find(op.name);
CHECK(it != stage_map.end()) << "Cannot find Stage for operator " << op.name << " in the schedule";
return it->second;
}

//! The output operations in original data flow graph
std::vector<ir::Operation> outputs;
/**
* \brief list of all stages for ops.
* The stages are sorted in dependency order.
*/
std::vector<poly::Stage> stages;

//! map of original operation to the stages
std::unordered_map<std::string, ir::Tensor> stage_map;
};
} // namespace framework
} // namespace hlir
} // namespace cinn
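
As a small usage sketch (an assumption, not code from this commit, and assuming schedule.h is included), a schedule builder would record each tensor under the name of the operation that produces it, matching the lookup key used by operator[]; RecordStage below is a hypothetical helper.

// Hypothetical helper: register a stage for `out` under its producing
// operation's name so that Schedule::operator[] can find it later.
void RecordStage(Schedule* sched, const std::string& op_name, const ir::Tensor& out) {
  sched->stage_map[op_name] = out;
}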
4 changes: 0 additions & 4 deletions cinn/utils/base.h

This file was deleted.

3 changes: 3 additions & 0 deletions cinn/utils/registry.h
@@ -211,3 +211,6 @@ class FunctionRegEntryBase {
*/
#define CINN_REGISTRY_REGISTER(EntryType, EntryTypeName, Name) \
static EntryType &__make_##EntryTypeName##_##Name##__ = ::Registry<EntryType>::Get()->__REGISTER__(#Name)

#define CINN_STR_CONCAT_(__x, __y) __x##__y
#define CINN_STR_CONCAT(__x, __y) CINN_STR_CONCAT_(__x, __y)
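
The two macros above are the standard two-level token-pasting idiom: the extra indirection lets arguments such as __LINE__ or __COUNTER__ expand before ## joins them. A small illustration follows; the variable name is hypothetical, not from this commit.

// On line 42 this expands to `static int _cinn_example_reg_42 = 0;`, which is
// useful for generating unique registration variable names.
static int CINN_STR_CONCAT(_cinn_example_reg_, __LINE__) = 0;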
