Skip to content

Commit

Permalink
Forgot visiting arg in ScheduleBuilder CallNode vsit
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Mar 11, 2022
1 parent 0c6d4a6 commit be6c258
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/relay/backend/te_compiler_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ class ScheduleBuilder : public ExprVisitor {
}

CachedFunc Create(const Function& relay_func, std::function<std::string(std::string)> renamer) {
LOG(INFO) << relay_func;
LowerToTECompute lower_te_compute(target_);
Array<te::Tensor> outputs = lower_te_compute.Lower(relay_func, renamer);
Array<te::Tensor> fn_inputs = lower_te_compute.fn_inputs_;
Expand Down Expand Up @@ -350,6 +351,8 @@ class ScheduleBuilder : public ExprVisitor {
// Use TOPI schedule if user specificed, or the function has no auto_scheduler schedule.
if (!schedule.defined() && !prim_func.defined()) {
ICHECK(lower_te_compute.anchor_implementation_.defined());
LOG(INFO) << lower_te_compute.candidate_name_;
LOG(INFO) << anchor_attrs_;
schedule =
lower_te_compute.anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_);
}
Expand All @@ -372,6 +375,10 @@ class ScheduleBuilder : public ExprVisitor {
ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
Op op = Downcast<Op>(call_node->op);

for (Expr arg : call_node->args) {
VisitExpr(arg);
}

int op_pattern = fpattern[op];
if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
Expand Down

0 comments on commit be6c258

Please sign in to comment.