Skip to content

Commit

Permalink
[SCHEDULE] Further fix for reduce inline with multiple outputs (#508)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored Oct 5, 2017
1 parent f631fb4 commit 2f4a5ad
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 12 deletions.
46 changes: 36 additions & 10 deletions src/op/compute_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)

TVM_REGISTER_NODE_TYPE(ComputeOpNode);

// Check whether two Reduce nodes are structurally identical in every
// component except value_index (combiner, source, axis, condition).
// Such nodes represent different outputs of one multi-output reduction.
inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
  if (!a->combiner.same_as(b->combiner)) return false;
  if (!a->source.same_as(b->source)) return false;
  if (!a->axis.same_as(b->axis)) return false;
  return a->condition.same_as(b->condition);
}

// Number of output tensors of this compute op: one per body expression.
int ComputeOpNode::num_outputs() const {
  return static_cast<int>(body.size());
}
Expand Down Expand Up @@ -98,13 +105,6 @@ Array<Tensor> compute(Array<Expr> shape,
return outputs;
}

// NOTE(review): this is the pre-move copy of ReduceEqual shown by the diff
// (the commit relocates it earlier in the file so ReplaceInputs can use it);
// in the final file only one definition remains.
// Returns true when two Reduce nodes share combiner, source, axis and
// condition — i.e. they differ at most in value_index.
inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
return (a->combiner.same_as(b->combiner)) &&
(a->source.same_as(b->source)) &&
(a->axis.same_as(b->axis)) &&
(a->condition.same_as(b->condition));
}

Operation ComputeOpNode::make(std::string name,
std::string tag,
Array<IterVar> axis,
Expand Down Expand Up @@ -151,9 +151,35 @@ Operation ComputeOpNode::ReplaceInputs(
const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const {
CHECK_EQ(self.operator->(), this);
Array<Expr> arr = UpdateArray(this->body, [&rmap] (const Expr& e) {
return op::ReplaceTensor(e, rmap);
});
Array<Expr> arr;
if (this->body[0]->is_type<ir::Reduce>()) {
// Specially handle reduce so the replaced op
// still share all the components
const ir::Reduce* reduce = this->body[0].as<ir::Reduce>();
for (size_t i = 1; i < this->body.size(); ++i) {
const ir::Reduce* reduce_ = this->body[i].as<ir::Reduce>();
CHECK(reduce_);
CHECK(ReduceEqual(reduce_, reduce))
<< "The Reduce inputs of ComputeOp should "
<< "have the same attribute except value_index";
}\
Expr new_reduce = op::ReplaceTensor(this->body[0], rmap);
if (!new_reduce.same_as(this->body[0])) {
const ir::Reduce* r = new_reduce.as<ir::Reduce>();
for (size_t k = 0; k < this->body.size(); ++k) {
std::shared_ptr<ir::Reduce> n = std::make_shared<ir::Reduce>(*r);
n->value_index = static_cast<int>(k);
n->type = r->source[k].type();
arr.push_back(Expr(n));
}
} else {
arr = this->body;
}
} else {
arr = UpdateArray(this->body, [&rmap] (const Expr& e) {
return op::ReplaceTensor(e, rmap);
});
}
if (!arr.same_as(this->body)) {
return ComputeOpNode::make(name, tag, axis, arr);
} else {
Expand Down
1 change: 1 addition & 0 deletions src/op/op_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ class TensorReplacer : public ir::IRMutator {
public:
// Construct a replacer over a tensor substitution map (old -> new tensor).
// The map is stored by reference in vmap_ — the caller must keep it alive
// for the lifetime of this mutator.
explicit TensorReplacer(const std::unordered_map<Tensor, Tensor>& vmap)
: vmap_(vmap) {}

Expr Mutate_(const ir::Call* op, const Expr& e) {
if (op->call_type == ir::Call::Halide) {
Tensor t = Operation(op->func.node_).output(op->value_index);
Expand Down
6 changes: 4 additions & 2 deletions tests/python/unittest/test_schedule_schedule_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,16 +68,18 @@ def argmax_init(idx_typ, val_typ):
m = tvm.var('m')
n = tvm.var('n')
val = tvm.placeholder((m, n), name='val', dtype='float32')
val2 = tvm.compute((m, n), lambda i, j: tvm.exp(val[i, j]), name='val2')
val1 = tvm.compute((m, n), lambda i, j: val[i, j]+1, name='val1')
val2 = tvm.compute((m, n), lambda i, j: tvm.exp(val1[i, j]), name='val2')
k = tvm.reduce_axis((0, n), 'k')
T_idx, T_val = tvm.compute((m, ), lambda i: argmax((k.var, val2[i, k]), axis=k), name='T')
s = tvm.create_schedule(T_idx.op)
s[val2].compute_inline()
s[val1].compute_inline()
s = s.normalize()
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)



def test_auto_inline():
m = tvm.var('m')
n = tvm.var('n')
Expand Down

0 comments on commit 2f4a5ad

Please sign in to comment.