Skip to content

Commit

Permalink
fixed anchor impl selection
Browse files — browse the repository at this point in the history
  • Loading branch information
masahi committed Mar 11, 2022
1 parent be6c258 commit 6f01901
Showing 1 changed file with 5 additions and 8 deletions.
13 changes: 5 additions & 8 deletions src/relay/backend/te_compiler_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ class LowerToTECompute : public backend::MemoizedExprTranslator<Array<te::Tensor

LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_);
Array<te::Tensor> outputs = lowered_out->outputs;
anchor_implementation_ = lowered_out->implementation;
op_implementations_[op.operator->()] = lowered_out->implementation;

if (outputs.size() != 1) {
const auto* tuple_type = call_node->checked_type().as<TupleTypeNode>();
Expand Down Expand Up @@ -276,8 +276,8 @@ class LowerToTECompute : public backend::MemoizedExprTranslator<Array<te::Tensor
Array<tvm::te::Tensor> fn_inputs_;
Array<te::Operation> scalars_;
std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors_;
std::unordered_map<const OpNode*, OpImplementation> op_implementations_;
std::string candidate_name_;
OpImplementation anchor_implementation_;

private:
tvm::Target target_;
Expand All @@ -300,7 +300,6 @@ class ScheduleBuilder : public ExprVisitor {
}

CachedFunc Create(const Function& relay_func, std::function<std::string(std::string)> renamer) {
LOG(INFO) << relay_func;
LowerToTECompute lower_te_compute(target_);
Array<te::Tensor> outputs = lower_te_compute.Lower(relay_func, renamer);
Array<te::Tensor> fn_inputs = lower_te_compute.fn_inputs_;
Expand Down Expand Up @@ -350,11 +349,9 @@ class ScheduleBuilder : public ExprVisitor {

// Use the TOPI schedule if the user specified one, or the function has no auto_scheduler schedule.
if (!schedule.defined() && !prim_func.defined()) {
ICHECK(lower_te_compute.anchor_implementation_.defined());
LOG(INFO) << lower_te_compute.candidate_name_;
LOG(INFO) << anchor_attrs_;
schedule =
lower_te_compute.anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_);
auto anchor_impl = lower_te_compute.op_implementations_.find(anchor_op_.operator->());
ICHECK(anchor_impl != lower_te_compute.op_implementations_.end());
schedule = anchor_impl->second.Schedule(anchor_attrs_, tensor_outs, target_);
}
if (schedule.defined()) {
for (const auto& scalar : lower_te_compute.scalars_) {
Expand Down

0 comments on commit 6f01901

Please sign in to comment.