Skip to content

Commit

Permalink
clean up integration.cc and Query interface
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Mar 11, 2022
1 parent 3f93a1e commit 99f1701
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 20 deletions.
6 changes: 3 additions & 3 deletions include/tvm/meta_schedule/integration.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class MetaScheduleContextNode : public runtime::Object {
* 3) relay::Function if `mod` should be dispatched to BYOC workflow
* 4) IRModule for unified dispatch
*/
virtual Optional<ObjectRef> Query(runtime::String task_name, IRModule mod, Target target,
virtual IRModule Query(runtime::String task_name, IRModule mod, Target target,
Optional<Array<IRModule>> dispatched) = 0;

static constexpr const char* _type_key = "meta_schedule.MetaScheduleContext";
Expand Down Expand Up @@ -129,7 +129,7 @@ class MetaScheduleContext : public runtime::ObjectRef {
* 3) relay::Function if `mod` should be dispatched to BYOC workflow
* 4) IRModule for unified dispatch
*/
static Optional<ObjectRef> QueryInsideWithScope(runtime::String task_name, IRModule mod,
static IRModule QueryInsideWithScope(runtime::String task_name, IRModule mod,
Target target,
Optional<Array<IRModule>> dispatched);

Expand Down Expand Up @@ -161,7 +161,7 @@ class ApplyHistoryBestNode : public MetaScheduleContextNode {
}

// Inherited from base class
Optional<ObjectRef> Query(runtime::String task_name, IRModule mod, Target target,
IRModule Query(runtime::String task_name, IRModule mod, Target target,
Optional<Array<IRModule>> dispatched) final;

static constexpr const char* _type_key = "meta_schedule.ApplyHistoryBest";
Expand Down
32 changes: 24 additions & 8 deletions src/meta_schedule/integration.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,21 @@ namespace tvm {
namespace meta_schedule {

/**************** Utility functions ****************/
template <class FunctionType>
Optional<GlobalVar> GetOnlyOneFunctionKey(const IRModule& mod) {
if (mod->functions.size() != 1) {
return NullOpt;
}
for (const auto& kv : mod->functions) {
const BaseFunc& func = kv.second;
if (!func->IsInstance<typename FunctionType::ContainerType>()) {
return NullOpt;
} else {
return kv.first;
}
}
return NullOpt;
}

template <class FunctionType>
Optional<FunctionType> GetOnlyOneFunction(const IRModule& mod) {
Expand Down Expand Up @@ -86,12 +101,13 @@ void MetaScheduleContext::ExitWithScope() {
ctx = NullOpt;
}

Optional<ObjectRef> MetaScheduleContext::QueryInsideWithScope(
runtime::String task_name, IRModule mod, Target target, Optional<Array<IRModule>> dispatched) {
IRModule MetaScheduleContext::QueryInsideWithScope(runtime::String task_name, IRModule mod,
Target target,
Optional<Array<IRModule>> dispatched) {
if (Optional<MetaScheduleContext> ctx = MetaScheduleContext::Current()) {
return ctx.value()->Query(task_name, mod, target, dispatched);
}
return NullOpt;
return IRModule{nullptr};
}

/**************** ApplyHistoryBest ****************/
Expand All @@ -102,14 +118,14 @@ ApplyHistoryBest::ApplyHistoryBest(Database database) {
data_ = n;
}

Optional<ObjectRef> ApplyHistoryBestNode::Query(runtime::String task_name, IRModule mod,
Target target,
Optional<Array<IRModule>> dispatched) {
IRModule ApplyHistoryBestNode::Query(runtime::String task_name, IRModule mod, Target target,
Optional<Array<IRModule>> dispatched) {
ICHECK(dispatched.defined());
ICHECK_EQ(dispatched.value().size(), 1);
ICHECK(HasOnlyOneFunction<relay::Function>(mod)) << mod;
IRModule prim_mod = dispatched.value()[0];
ICHECK(HasOnlyOneFunction<tir::PrimFunc>(prim_mod)) << prim_mod;
auto gv = GetOnlyOneFunctionKey<tir::PrimFunc>(prim_mod).value();
// Unify func name to make sure it can be found in database
const auto* parse_mod_func = runtime::Registry::Get("tvm.meta_schedule.tune.parse_mod");
ICHECK(parse_mod_func) << "Parse mod function not defined!";
Expand All @@ -122,11 +138,11 @@ Optional<ObjectRef> ApplyHistoryBestNode::Query(runtime::String task_name, IRMod
/*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
records[0]->trace->ApplyToSchedule(sch, false);
tir::PrimFunc func = GetOnlyOneFunction<tir::PrimFunc>(sch->mod()).value();
return func;
return IRModule({{gv, func}});
}
}
LOG(WARNING) << "Cannot find workload: " << task_name << "\n" << tir::AsTVMScript(prim_mod);
return NullOpt;
return IRModule{nullptr};
}

/**************** FFI ****************/
Expand Down
18 changes: 9 additions & 9 deletions src/relay/backend/te_compiler_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#include "../../te/operation/create_primfunc.h"
#include "../op/memory/memory.h"
#include "../transforms/pass_utils.h"
#include "tvm/runtime/object.h"
#include "utils.h"

namespace tvm {
Expand Down Expand Up @@ -335,15 +336,14 @@ class ScheduleBuilder : public ExprVisitor {
}
}
if (backend::IsMetaScheduleEnabled()) {
prim_func = tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs));
Optional<ObjectRef> opt_mod_or_base_func =
meta_schedule::MetaScheduleContext::QueryInsideWithScope(
prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_,
Array<IRModule>{IRModule({{prim_fn_var, prim_func}})});
if (const auto* result = opt_mod_or_base_func.as<tir::PrimFuncNode>()) {
prim_func = GetRef<tir::PrimFunc>(result);
} else {
prim_func = tir::PrimFunc(nullptr);
auto relay_mod = IRModule({{prim_fn_var, relay_func}});
auto tir_mod =
IRModule({{prim_fn_var, tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs))}});
IRModule scheduled_mod = meta_schedule::MetaScheduleContext::QueryInsideWithScope(
prim_fn_var->name_hint, relay_mod, target_, Array<IRModule>{tir_mod});
if (scheduled_mod.defined()) {
ICHECK_EQ(scheduled_mod->functions.count(prim_fn_var), 1);
prim_func = Downcast<tir::PrimFunc>(scheduled_mod->functions[prim_fn_var]);
}
}

Expand Down

0 comments on commit 99f1701

Please sign in to comment.