add doc to util functions
masahi committed Mar 11, 2022
1 parent 3c5a318 commit 7b4d35e
Showing 4 changed files with 19 additions and 4 deletions.
4 changes: 3 additions & 1 deletion src/relay/backend/task_extraction.cc
@@ -38,6 +38,7 @@ using meta_schedule::ExtractedTask;
 Array<ExtractedTask> ExtractTask(IRModule mod, Target target, Map<String, Constant> params) {
   backend::BindParamsInModule(mod, params);
 
+  // is_vm=true for backward compatibility
   Array<Pass> pass_seqs = relay::backend::GetPassPrefix(/*is_homogenous=*/true, /*is_vm=*/true);
   pass_seqs.push_back(transform::FuseOps());
 
@@ -51,7 +52,8 @@ Array<ExtractedTask> ExtractTask(IRModule mod, Target target, Map<String, Constant> params) {
   if (relay_func->HasNonzeroAttr(attr::kPrimitive)) {
     Array<te::Tensor> outputs;
     std::string fused_name;
-    std::tie(outputs, fused_name) = tec::LowerTECompute(target, relay_func);
+    std::tie(outputs, fused_name) =
+        tec::LowerTECompute(relay_func, target, /*return_inputs*/ true);
     auto prim_func = tir::CreatePrimFunc(outputs);
     auto prim_fn_var = GlobalVar(fused_name);
     auto relay_mod = IRModule({{prim_fn_var, relay_func}});
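Putting the two hunks above together, the new call-site pattern is roughly the sketch below. This is a minimal sketch, not verbatim code from the commit; it assumes the surrounding `relay_func` and `target` variables from this file, and the comment about `tir::CreatePrimFunc` reflects the `/*return_inputs*/ true` argument shown in the diff.

    // Sketch of the updated call site (argument order: function first, then target).
    // With return_inputs=true, `outputs` holds the input placeholder tensors
    // followed by the compute outputs, i.e. the full argument list that
    // tir::CreatePrimFunc turns into a PrimFunc signature.
    Array<te::Tensor> outputs;
    std::string fused_name;
    std::tie(outputs, fused_name) =
        tec::LowerTECompute(relay_func, target, /*return_inputs=*/true);
    auto prim_func = tir::CreatePrimFunc(outputs);
    auto prim_fn_var = GlobalVar(fused_name);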
4 changes: 2 additions & 2 deletions src/relay/backend/te_compiler_cache.cc
@@ -754,10 +754,10 @@ CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target,
   return MakeShapeFunc().Create(prim_func, target, renamer);
 }
 
-std::pair<Array<te::Tensor>, std::string> LowerTECompute(Target target, const Function& relay_func,
+std::pair<Array<te::Tensor>, std::string> LowerTECompute(const Function& source_func, Target target,
                                                          bool return_inputs) {
   LowerToTECompute lower_te_compute(target);
-  auto outputs = lower_te_compute.Lower(relay_func, [&](std::string name) { return name; });
+  auto outputs = lower_te_compute.Lower(source_func, [&](std::string name) { return name; });
   // Following ScheduleBuilder, remove placeholder ops from outputs.
   tvm::Array<te::Tensor> tensor_outs;
   for (const auto& tensor : outputs) {
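The loop body is collapsed in this view. Based on the comment above it ("remove placeholder ops from outputs"), it plausibly filters out input placeholders, along these lines; this is an inferred sketch, not the commit's verbatim code.

    // Inferred sketch: keep only tensors produced by real compute ops,
    // dropping te::PlaceholderOpNode entries (the inputs).
    for (const auto& tensor : outputs) {
      if (!tensor->op.as<te::PlaceholderOpNode>()) {
        tensor_outs.push_back(tensor);
      }
    }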
10 changes: 9 additions & 1 deletion src/relay/backend/te_compiler_cache.h
@@ -204,7 +204,15 @@ class CCacheValue : public ObjectRef {
 
 Array<IndexExpr> GetShape(const Array<IndexExpr>& shape);
 
-std::pair<Array<te::Tensor>, std::string> LowerTECompute(Target target, const Function& relay_func, bool return_inputs=true);
+/*!
+ * \brief Lower a Relay primitive Function to TE compute.
+ * \param source_func The primitive function to be lowered.
+ * \param target The target we want to create a schedule for.
+ * \param return_inputs If true, prepend input tensors to the output array of tensors.
+ * \return Pair of output TE tensors and the fused function name.
+ */
+std::pair<Array<te::Tensor>, std::string> LowerTECompute(const Function& source_func, Target target,
+                                                         bool return_inputs = true);
 
 /*!
  * \brief Create schedule for target.
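A short usage sketch against the documented signature; `relay_func` and `target` are placeholder names, and the behavior comments restate the \param text above rather than anything beyond it.

    // Hypothetical usage of LowerTECompute as documented above.
    Array<te::Tensor> tensors;
    std::string fused_name;
    // Default return_inputs=true: input tensors are prepended, so `tensors`
    // is [inputs..., outputs...].
    std::tie(tensors, fused_name) = LowerTECompute(relay_func, target);
    // return_inputs=false: only the compute output tensors are returned.
    std::tie(tensors, fused_name) =
        LowerTECompute(relay_func, target, /*return_inputs=*/false);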
5 changes: 5 additions & 0 deletions src/relay/backend/utils.h
@@ -389,6 +389,11 @@ inline std::string DType2String(const tvm::DataType dtype) {
 relay::Function BindParamsByName(relay::Function func,
                                  const std::unordered_map<std::string, runtime::NDArray>& params);
 
+/*!
+ * \brief Bind params to the "main" function of a Relay module, using BindParamsByName.
+ * \param mod The Relay module.
+ * \param params The parameters to bind, keyed by name.
+ */
 void BindParamsInModule(IRModule mod,
                         const std::unordered_map<std::string, runtime::NDArray>& params);
 
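Unlike BindParamsByName above, which returns the rewritten Function, BindParamsInModule works on the module itself. A small hedged usage sketch, where `mod` is a placeholder IRModule and the NDArray value is assumed to be loaded elsewhere:

    // Hypothetical usage; "dense_weight" is a placeholder parameter name.
    std::unordered_map<std::string, runtime::NDArray> params;
    params["dense_weight"] = LoadWeightSomehow();  // assumed helper, not a real API
    // Binds matching free vars of mod's "main" function to the given constants.
    BindParamsInModule(mod, params);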
