[MetaSchedule] Apply-History-Best Task Filtering
This PR enables task filtering in Apply-History-Best, which is used in the
Relay/Relax integration. Previously, even if a task was ruled out during
task extraction, it would still show up during Relay compilation because
`Apply-History-Best` performed no filtering of its own. The TE-to-TIR
conversion `te.CreatePrimFunc`, however, does not support every case
involving hybrid operators, so the unfiltered tasks caused post-tuning
failures in multiple models.
junrushao committed Jun 13, 2022
1 parent 85a190a commit 630c605
Showing 8 changed files with 113 additions and 56 deletions.
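
For context, the sketch below shows how the new parameter surfaces at the Python level. It is a minimal sketch, not part of the PR: the database paths, `mod`, and `params` are hypothetical, and `te_filter_func=None` (the default) selects the built-in `DefaultTaskFilter` introduced by this change:

    import tvm
    from tvm import meta_schedule as ms
    from tvm import relay

    # Hypothetical pre-tuned database; mod/params come from a Relay importer.
    database = ms.database.JSONDatabase("workload.json", "tuning_record.json")

    # Tasks rejected by the filter are now skipped at compile time as well,
    # instead of crashing inside te.CreatePrimFunc after tuning.
    with ms.ApplyHistoryBest(database, te_filter_func=None):
        with tvm.transform.PassContext(
            opt_level=3,
            config={"relay.backend.use_meta_schedule": True},
        ):
            lib = relay.build(mod, target="llvm", params=params)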
21 changes: 19 additions & 2 deletions include/tvm/meta_schedule/apply_history_best.h
@@ -29,6 +29,12 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/target/target.h>

namespace tvm {
namespace te {
class Tensor;
} // namespace te
} // namespace tvm

namespace tvm {
namespace meta_schedule {

@@ -38,12 +44,21 @@ namespace meta_schedule {
*/
class ApplyHistoryBestNode : public runtime::Object {
public:
using FTEFilterFunc =
runtime::TypedPackedFunc<Optional<tir::PrimFunc>(const Array<te::Tensor, void>&)>;

/*! \brief The database to be queried from */
Database database{nullptr};
/*! \brief The filtering function for TE computation */
FTEFilterFunc te_filter_func{nullptr};
/*! \brief The logging function to be used */
PackedFunc logging_func;

void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("database", &database); }
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("database", &database);
// `te_filter_func` is not visited
// `logging_func` is not visited
}
/*!
* \brief Query the best entry from the database
* \param task_name The name of the task to be queried
@@ -67,9 +82,11 @@ class ApplyHistoryBest : public runtime::ObjectRef {
/*!
* \brief Constructor
* \param database The database to be queried from
* \param te_filter_func The filtering function for TE computation
* \param logging_func The logging function to use
*/
explicit ApplyHistoryBest(Database database, PackedFunc logging_func);
explicit ApplyHistoryBest(Database database, ApplyHistoryBestNode::FTEFilterFunc te_filter_func,
PackedFunc logging_func);
/*!
* \brief The current ApplyHistoryBest in the context
* \return The ApplyHistoryBest in the current scope.
11 changes: 11 additions & 0 deletions include/tvm/meta_schedule/extracted_task.h
@@ -26,6 +26,15 @@
#include <tvm/runtime/object.h>
#include <tvm/target/target.h>

namespace tvm {
namespace tir {
class PrimFunc;
} // namespace tir
namespace te {
class Tensor;
} // namespace te
} // namespace tvm

namespace tvm {
namespace meta_schedule {

@@ -67,6 +76,8 @@ class ExtractedTask : public runtime::ObjectRef {
ExtractedTaskNode);
};

Optional<tvm::tir::PrimFunc> DefaultTaskFilter(const Array<tvm::te::Tensor, void>& args);

} // namespace meta_schedule
} // namespace tvm

17 changes: 14 additions & 3 deletions python/tvm/meta_schedule/apply_history_best.py
@@ -16,12 +16,14 @@
# under the License.
"""A context manager that injects the best tuning record in the database into compilation"""
import logging
from typing import List, Optional, Union
from typing import Callable, List, Optional, Union

from tvm._ffi import register_object
from tvm.ir import IRModule
from tvm.runtime import Object
from tvm.target import Target
from tvm.te import Tensor
from tvm.tir import PrimFunc

from . import _ffi_api
from .database import Database
@@ -38,13 +40,22 @@ class ApplyHistoryBest(Object):
----------
database : Database
The database to be queried from
te_filter_func : Optional[Callable[[List[Tensor]], PrimFunc]] = None
The filtering function for TE computation
"""

database: Database

def __init__(self, database: Database) -> None:
def __init__(
self,
database: Database,
te_filter_func: Optional[Callable[[List[Tensor]], PrimFunc]] = None,
) -> None:
self.__init_handle_by_constructor__(
_ffi_api.ApplyHistoryBest, database, make_logging_func(logger) # type: ignore # pylint: disable=no-member
_ffi_api.ApplyHistoryBest, # type: ignore # pylint: disable=no-member
database,
te_filter_func,
make_logging_func(logger),
)

def query(
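As an illustration of the new signature, a custom filter is any callable from a list of TE tensors to an optional `PrimFunc`. The sketch below is illustrative only (the function name is invented); it rejects computations with symbolic shapes and otherwise defers to `te.create_prim_func`:

    from typing import List, Optional

    from tvm import te, tir

    def my_te_filter(args: List[te.Tensor]) -> Optional[tir.PrimFunc]:
        # Reject any tensor whose shape is not fully static.
        for tensor in args:
            if not all(isinstance(dim, tir.IntImm) for dim in tensor.shape):
                return None
        return te.create_prim_func(args)

    # Usage: ApplyHistoryBest(database, te_filter_func=my_te_filter)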
15 changes: 12 additions & 3 deletions src/meta_schedule/apply_history_best.cc
@@ -16,6 +16,8 @@
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/te/tensor.h>

#include "./utils.h"

namespace tvm {
@@ -87,10 +89,16 @@ void ApplyHistoryBest::ExitWithScope() {

/**************** ApplyHistoryBest ****************/

ApplyHistoryBest::ApplyHistoryBest(Database database, PackedFunc logging_func) {
ApplyHistoryBest::ApplyHistoryBest(Database database,
ApplyHistoryBestNode::FTEFilterFunc te_filter_func,
PackedFunc logging_func) {
ObjectPtr<ApplyHistoryBestNode> n = make_object<ApplyHistoryBestNode>();
n->database = database;
n->te_filter_func = te_filter_func;
n->logging_func = logging_func;
if (te_filter_func == nullptr) {
n->te_filter_func = DefaultTaskFilter;
}
data_ = n;
}

@@ -129,8 +137,9 @@ Optional<IRModule> ApplyHistoryBestNode::Query(runtime::String task_name, IRModu

TVM_REGISTER_NODE_TYPE(ApplyHistoryBestNode);
TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBest")
.set_body_typed([](Database database, PackedFunc logging_func) -> ApplyHistoryBest {
return ApplyHistoryBest(database, logging_func);
.set_body_typed([](Database database, ApplyHistoryBestNode::FTEFilterFunc te_filter_func,
PackedFunc logging_func) -> ApplyHistoryBest {
return ApplyHistoryBest(database, te_filter_func, logging_func);
});
TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBestEnterScope")
.set_body_typed(ApplyHistoryBestInternal::EnterScope);
42 changes: 42 additions & 0 deletions src/meta_schedule/extracted_task.cc
@@ -17,6 +17,12 @@
* under the License.
*/
#include <tvm/meta_schedule/extracted_task.h>
#include <tvm/te/operation.h>
#include <tvm/te/tensor.h>
#include <tvm/tir/function.h>

#include "../te/operation/create_primfunc.h"
#include "./utils.h"

namespace tvm {
namespace meta_schedule {
@@ -32,6 +38,42 @@ ExtractedTask::ExtractedTask(String task_name, IRModule mod, Target target,
data_ = n;
}

Optional<tir::PrimFunc> DefaultTaskFilter(const Array<te::Tensor>& args) {
using namespace ::tvm::te;
std::vector<Tensor> stack;
std::unordered_set<const TensorNode*> visited;
for (const Tensor& v : args) {
for (const PrimExpr& e : v->shape) {
// Dynamic shape is not supported for now
if (!e->IsInstance<IntImmNode>()) {
return NullOpt;
}
}
if (!visited.count(v.get())) {
visited.insert(v.get());
stack.push_back(v);
}
}
while (!stack.empty()) {
Tensor tensor = stack.back();
stack.pop_back();
if (tensor->op->IsInstance<PlaceholderOpNode>()) {
// do nothing
} else if (tensor->op->IsInstance<ComputeOpNode>()) {
Array<Tensor> inputs = tensor->op->InputTensors();
for (const Tensor& v : inputs) {
if (!visited.count(v.get())) {
visited.insert(v.get());
stack.push_back(v);
}
}
} else {
return NullOpt;
}
}
return te::CreatePrimFunc(args);
}

TVM_REGISTER_NODE_TYPE(ExtractedTaskNode);
TVM_REGISTER_GLOBAL("meta_schedule.ExtractedTask")
.set_body_typed([](String task_name, IRModule mod, Target target, Array<IRModule> dispatched,
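To make the traversal above concrete, here is a small Python sketch of inputs the default filter accepts and rejects (tensor names are illustrative only):

    from tvm import te

    # Accepted: static shapes, and the DAG contains only placeholder
    # and compute ops, so the filter returns te.CreatePrimFunc(args).
    A = te.placeholder((128, 128), name="A")
    B = te.compute((128, 128), lambda i, j: A[i, j] * 2.0, name="B")

    # Rejected: the symbolic dimension `n` is not an IntImm, so the
    # filter returns NullOpt and the task is skipped.
    n = te.var("n")
    C = te.placeholder((n,), name="C")
    D = te.compute((n,), lambda i: C[i] + 1.0, name="D")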
1 change: 1 addition & 0 deletions src/meta_schedule/utils.h
@@ -26,6 +26,7 @@
#include <tvm/meta_schedule/builder.h>
#include <tvm/meta_schedule/cost_model.h>
#include <tvm/meta_schedule/database.h>
#include <tvm/meta_schedule/extracted_task.h>
#include <tvm/meta_schedule/feature_extractor.h>
#include <tvm/meta_schedule/measure_callback.h>
#include <tvm/meta_schedule/profiler.h>
45 changes: 4 additions & 41 deletions src/relay/backend/task_extraction.cc
@@ -31,48 +31,12 @@ namespace tvm {
namespace relay {
namespace backend {

bool DefaultTaskFilter(const Array<te::Tensor>& args) {
using namespace ::tvm::te;
std::vector<Tensor> stack;
std::unordered_set<const TensorNode*> visited;
for (const Tensor& v : args) {
for (const PrimExpr& e : v->shape) {
// Dynamic shape is not supported for now
if (!e->IsInstance<IntImmNode>()) {
return false;
}
}
if (!visited.count(v.get())) {
visited.insert(v.get());
stack.push_back(v);
}
}
while (!stack.empty()) {
Tensor tensor = stack.back();
stack.pop_back();
if (tensor->op->IsInstance<PlaceholderOpNode>()) {
// do nothing
} else if (tensor->op->IsInstance<ComputeOpNode>() || tensor->op->IsInstance<ExternOpNode>()) {
Array<Tensor> inputs = tensor->op->InputTensors();
for (const Tensor& v : inputs) {
if (!visited.count(v.get())) {
visited.insert(v.get());
stack.push_back(v);
}
}
} else {
return false;
}
}
return true;
}

Array<meta_schedule::ExtractedTask> ExtractTask(
IRModule mod, Target target, Map<String, runtime::NDArray> params,
runtime::TypedPackedFunc<bool(const Array<te::Tensor>&)> filter_func) {
runtime::TypedPackedFunc<Optional<tir::PrimFunc>(const Array<te::Tensor>&)> filter_func) {
using meta_schedule::ExtractedTask;
if (filter_func == nullptr) {
filter_func = DefaultTaskFilter;
filter_func = tvm::meta_schedule::DefaultTaskFilter;
}
backend::BindParamsInModule(mod, params);
// is_vm=true for backward compatibility
@@ -98,11 +62,10 @@ Array<meta_schedule::ExtractedTask> ExtractTask(
std::string fused_name;
std::tie(inputs_outputs, fused_name) =
tec::LowerTECompute(relay_func, target, /*return_inputs=*/true);
if (filter_func(inputs_outputs)) {
tir::PrimFunc prim_func = tir::CreatePrimFunc(inputs_outputs);
if (Optional<tir::PrimFunc> prim_func = filter_func(inputs_outputs)) {
GlobalVar prim_fn_var(fused_name);
IRModule relay_mod({{prim_fn_var, relay_func}});
IRModule tir_mod({{prim_fn_var, prim_func}});
IRModule tir_mod({{prim_fn_var, prim_func.value()}});
ExtractedTask extracted_task(fused_name, relay_mod, target, {tir_mod}, 1);
tasks.push_back(extracted_task);
cache.emplace(cache_key, extracted_task);
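For reference, extraction itself is unchanged from the user's point of view. A sketch, assuming a Relay `mod`/`params` pair and the `extract_task_from_relay` helper of this era; tasks the filter rejects never make it into the returned list, and after this PR the same filter is consulted again when `ApplyHistoryBest` is queried:

    from tvm import meta_schedule as ms

    # mod and params are assumed to come from a Relay frontend importer.
    extracted_tasks = ms.extract_task_from_relay(mod, target="llvm", params=params)
    for task in extracted_tasks:
        print(task.task_name, task.weight)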
17 changes: 10 additions & 7 deletions src/relay/backend/te_compiler_cache.cc
@@ -346,15 +346,18 @@ class ScheduleBuilder : public ExprVisitor {
}
}
if (meta_schedule_ctx_) {
IRModule relay_mod({{prim_fn_var, relay_func}});
IRModule tir_mod({{prim_fn_var, tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs))}});
if (Optional<IRModule> scheduled_mod = meta_schedule_ctx_.value()->Query(
prim_fn_var->name_hint, relay_mod, target_, Array<IRModule>{tir_mod})) {
ICHECK_EQ(scheduled_mod.value()->functions.count(prim_fn_var), 1);
prim_func = Downcast<tir::PrimFunc>(scheduled_mod.value()->functions[prim_fn_var]);
Array<te::Tensor> te_args = Concat(fn_inputs, tensor_outs);
if (Optional<tir::PrimFunc> tir_func =
meta_schedule_ctx_.value()->te_filter_func(te_args)) {
IRModule relay_mod({{prim_fn_var, relay_func}});
IRModule tir_mod({{prim_fn_var, tir_func.value()}});
if (Optional<IRModule> scheduled_mod = meta_schedule_ctx_.value()->Query(
prim_fn_var->name_hint, relay_mod, target_, Array<IRModule>{tir_mod})) {
ICHECK_EQ(scheduled_mod.value()->functions.count(prim_fn_var), 1);
prim_func = Downcast<tir::PrimFunc>(scheduled_mod.value()->functions[prim_fn_var]);
}
}
}

// Use TOPI schedule if user specificed, or the function has no auto_scheduler schedule.
if (!schedule.defined() && !prim_func.defined()) {
if (anchor_op_.defined()) {
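The net effect of this last change: `ScheduleBuilder` now consults the context's `te_filter_func` before building a TIR module for the query, and when the filter returns `NullOpt` the MetaSchedule lookup is skipped entirely, letting compilation fall through to the TOPI schedule path in the surrounding context.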
