forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add strategy module (PaddlePaddle#167)
* add primitive layer. add function primitive::add * optimize the primitive::add function * change the name and struct of add function * change opfunction to pe. * add op module. test ci * add class graph and class node. early version * add pass module * add example pass and unittest * add strategy module * fix function name problem
- Loading branch information
Showing
8 changed files
with
199 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
#pragma once | ||
#include <memory> | ||
#include <string> | ||
#include <utility> | ||
#include <vector> | ||
#include "cinn/hlir/framework/schedule.h" | ||
#include "cinn/ir/packed_func.h" | ||
|
||
using CINNCompute = cinn::ir::PackedFunc; | ||
using CINNSchedule = cinn::ir::PackedFunc; | ||
|
||
namespace cinn { | ||
namespace hlir { | ||
namespace framework { | ||
|
||
//! Operator implementation that includes compute and schedule function. | ||
class OpImpl : public common::Object { | ||
public: | ||
//! Compute function | ||
CINNCompute fcompute; | ||
//! Schedule function | ||
CINNSchedule fschedule; | ||
//! Name of the implementation | ||
std::string name; | ||
//! Priority level | ||
int plevel; | ||
/** | ||
* \brief Invoke the operator compute function. | ||
* @param attrs The attribute of the primitive | ||
* @param inputs The input tensors. | ||
* @param out_type The output type information. | ||
* @return The output compute description of the operator. | ||
*/ | ||
ir::Tensor Compute(const std::vector<ir::Tensor>& inputs, const Type& out_type) { | ||
// TODO(haozech) : add support for packedfunc to return Tensor | ||
// Expected : return this->fcompute(inputs, out_type); | ||
ir::Tensor temp; | ||
return temp; | ||
} | ||
/** | ||
* \brief Build the computation schedule. | ||
* @param attrs The attribute of the node. | ||
* @param outs The output tensors. | ||
* @param target The build target. | ||
* @return The computation schedule. | ||
*/ | ||
common::Shared<Schedule> GetSchedule(const std::vector<ir::Tensor>& outs, | ||
const std::vector<ir::Tensor>& temp_tensors, | ||
const Target& target) { | ||
// TODO(haozech) : add support for packedfunc to return Schedule | ||
// Expected : return this->fschedule(outs, target); | ||
return nullptr; | ||
} | ||
|
||
const char* type_info() const override { return _type_key; } | ||
|
||
private: | ||
static constexpr char* _type_key = "OpImplementation"; | ||
}; | ||
|
||
//! Specialized implementations for operators under certain conditions. | ||
class OpSpec : public common::Object { | ||
public: | ||
//! List of implementations. | ||
std::vector<OpImpl*> implementations; | ||
|
||
/** \brief Condition to enable the specialization. | ||
* Could be undefined to represent generic case. | ||
* TODO(haozech) : build a specified class SpecializedCondition to represent the condition. | ||
* Expected : SpecializedCondition condition; | ||
*/ | ||
std::string condition; | ||
|
||
const char* type_info() const override { return _type_key; } | ||
|
||
void AddImpl(CINNCompute fcompute, CINNSchedule fschedule, std::string name, int plevel) { | ||
auto n = make_shared<OpImpl>(); | ||
n->fcompute = fcompute; | ||
n->fschedule = fschedule; | ||
n->name = std::move(name); | ||
n->plevel = plevel; | ||
this->implementations.push_back(n); | ||
} | ||
|
||
private: | ||
static constexpr char* _type_key = "OpSpecialization"; | ||
}; | ||
|
||
//! Operator strategy class. | ||
class OpStrategy : public common::Object { | ||
public: | ||
const char* type_info() const override { return "CINNOpStrategy"; } | ||
//! List of operator specializations. | ||
std::vector<OpSpec*> specializations; | ||
|
||
/** | ||
* \brief Add an implementation. | ||
* @param fcompute Compute function | ||
* @param fschedule Schedule function | ||
* @param name Name of the implementation | ||
* @param plevel Priority level of the implementation | ||
*/ | ||
void AddImpl(CINNCompute fcompute, CINNSchedule fschedule, std::string name, int plevel) { | ||
//! TODO(haozech) : here curr_cond should get the condition from outside. | ||
//! Expected : auto curr_cond = SpecializedCondition::Current(); | ||
std::string curr_cond = "current_condition"; | ||
OpSpec* op_spec; | ||
for (OpSpec* op_spec : specializations) { | ||
if (op_spec->condition == curr_cond) { | ||
op_spec->AddImpl(fcompute, fschedule, std::move(name), plevel); | ||
return; | ||
} | ||
} | ||
OpSpec* n = make_shared<OpSpec>(); | ||
n->condition = curr_cond; | ||
n->AddImpl(fcompute, fschedule, std::move(name), plevel); | ||
this->specializations.push_back(n); | ||
} | ||
}; | ||
|
||
} // namespace framework | ||
} // namespace hlir | ||
} // namespace cinn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
#pragma once | ||
#include <string> | ||
#include <unordered_map> | ||
#include <vector> | ||
#include "cinn/cinn.h" | ||
#include "cinn/lang/tensor.h" | ||
namespace cinn { | ||
namespace hlir { | ||
namespace framework { | ||
/** | ||
* \brief Global schedule container | ||
* For operations and all the operations they depend on. | ||
* The schedule per Operation is named as stage. | ||
*/ | ||
class Schedule : public common::Object { | ||
public: | ||
const char* type_info() const override { return "CINNSchedule"; } | ||
|
||
/** | ||
* \brief Get the stage corresponds to the op | ||
* @param op The operation. | ||
*/ | ||
ir::Tensor operator[](const ir::Operation& op) { | ||
auto it = stage_map.find(op.name); | ||
CHECK(it != stage_map.end()) << "Cannot find Stage for operator " << op.name << " in the schedule"; | ||
return it->second; | ||
} | ||
|
||
//! The output operations in original data flow graph | ||
std::vector<ir::Operation> outputs; | ||
/** | ||
* \brief list of all stages for ops. | ||
* The stages are sorted in dependency order. | ||
*/ | ||
std::vector<poly::Stage> stages; | ||
|
||
//! map of original operation to the stages | ||
std::unordered_map<std::string, ir::Tensor> stage_map; | ||
}; | ||
} // namespace framework | ||
} // namespace hlir | ||
} // namespace cinn |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters