Skip to content

Commit

Permalink
[SCHEDULE] Fuse support for 0 rank tensor (apache#1328)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored and sergei-mironov committed Aug 8, 2018
1 parent 703cb67 commit 8464d12
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 15 deletions.
39 changes: 36 additions & 3 deletions include/tvm/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,20 @@ class Stage : public NodeRef {
* \return reference to self.
*/
EXPORT Stage& fuse(IterVar outer, IterVar inner, IterVar* p_target); // NOLINT(*)
/*!
* \brief Fuse all the axes together into a single axis.
*
* \param axes All the axes to be fused.
* \param p_target The result target domain.
*
* \note axes can be an empty array,
* in that case, a singleton itervar is created and
* inserted to the outermost loop.
* The fuse of empty array is used to support zero-dimension tensors.
*
* \return reference to self.
*/
EXPORT Stage& fuse(const Array<IterVar>& axes, IterVar* p_target); // NOLINT(*)
/*!
* \brief Reorder the iteration
* \param order The order of iteration variable.
Expand All @@ -151,9 +165,9 @@ class Stage : public NodeRef {
* \return reference to self.
*/
EXPORT Stage& tile(IterVar x_parent, IterVar y_parent, // NOLINT(*)
Expr x_factor, Expr y_factor,
IterVar* p_x_outer, IterVar* p_y_outer,
IterVar* p_x_inner, IterVar* p_y_inner);
Expr x_factor, Expr y_factor,
IterVar* p_x_outer, IterVar* p_y_outer,
IterVar* p_x_inner, IterVar* p_y_inner);
/*!
* \brief Vectorize iteration.
* \param var The axis to be vectorized.
Expand Down Expand Up @@ -674,6 +688,25 @@ class RebaseNode : public IterVarRelationNode {
};


/*!
* \brief Singleton iterator [0, 1)
*/
class SingletonNode : public IterVarRelationNode {
public:
/*! \brief The singleton iterator */
IterVar iter;

void VisitAttrs(AttrVisitor* v) final {
v->Visit("iter", &iter);
}

static IterVarRelation make(IterVar iter);

static constexpr const char* _type_key = "Singleton";
TVM_DECLARE_NODE_TYPE_INFO(SingletonNode, IterVarRelationNode);
};


// implementations
inline const StageNode* Stage::operator->() const {
return static_cast<const StageNode*>(node_.get());
Expand Down
11 changes: 7 additions & 4 deletions python/tvm/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,12 @@ class Fuse(NodeBase):
pass


@register_node
class Singleton(NodeBase):
"""Singleton axis."""
pass


@register_node
class IterVar(NodeBase, _expr.ExprOp):
"""Represent iteration variable.
Expand Down Expand Up @@ -380,10 +386,7 @@ def fuse(self, *args):
fused : IterVar
The fused variable of iteration.
"""
assert len(args) >= 1, "Length of the arguments must be >=1 for fuse."
fused = args[0]
for i in range(1, len(args)):
fused = _api_internal._StageFuse(self, fused, args[i])
fused = _api_internal._StageFuse(self, args)
return fused

def set_scope(self, scope):
Expand Down
2 changes: 1 addition & 1 deletion src/api/api_lang.cc
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ TVM_REGISTER_API("_StageFuse")
.set_body([](TVMArgs args, TVMRetValue* ret) {
IterVar fused;
args[0].operator Stage()
.fuse(args[1], args[2], &fused);
.fuse(args[1], &fused);
*ret = fused;
});

Expand Down
11 changes: 11 additions & 0 deletions src/schedule/message_passing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ void PassDownDomain(const Stage& stage,
Update(p_state, r->rebased,
Range::make_by_min_extent(
0, state.at(r->parent)->extent));
} else if (const SingletonNode* s = rel.as<SingletonNode>()) {
Update(p_state, s->iter, Range::make_by_min_extent(0, 1));
} else {
LOG(FATAL) << "unknown relation type";
}
Expand Down Expand Up @@ -147,6 +149,7 @@ void PassUpIndex(const Stage& stage,
} else {
state[s->parent] = value;
}
} else if (rel.as<SingletonNode>()) {
} else {
LOG(FATAL) << "unknown relation type";
}
Expand Down Expand Up @@ -192,6 +195,8 @@ void PassDownIndex(const Stage& stage,
Expr parent_min = dom_map.at(s->parent)->min;
CHECK(is_zero(parent_min));
state[s->rebased] = value;
} else if (const SingletonNode* s = rel.as<SingletonNode>()) {
state[s->iter] = make_zero(s->iter->var.type());
} else {
LOG(FATAL) << "unknown relation type";
}
Expand Down Expand Up @@ -296,6 +301,7 @@ void PassUpDomain(const Stage& stage,
state.at(r->rebased),
&parent);
state[r->parent] = parent;
} else if (rel.as<SingletonNode>()) {
} else {
LOG(FATAL) << "unknown relation type";
}
Expand Down Expand Up @@ -344,6 +350,7 @@ void PassUpBitMaskOr(const Stage& stage,
} else {
state[s->parent] |= state[s->rebased];
}
} else if (rel.as<SingletonNode>()) {
} else {
LOG(FATAL) << "unknown relation type";
}
Expand Down Expand Up @@ -390,6 +397,8 @@ void PassDownBitMaskOr(const Stage& stage,
} else {
state[s->rebased] |= state.at(s->parent);
}
} else if (const SingletonNode* s = rel.as<SingletonNode>()) {
state[s->iter] = 0;
} else {
LOG(FATAL) << "unknown relation type";
}
Expand Down Expand Up @@ -438,6 +447,8 @@ void PassUpBoundCheck(const Stage& s,
} else if (rel.as<RebaseNode>()) {
const RebaseNode* s = rel.as<RebaseNode>();
state[s->parent] = state.at(s->rebased);
} else if (rel.as<SingletonNode>()) {
// nop
} else {
LOG(FATAL) << "unknown relation type";
}
Expand Down
38 changes: 37 additions & 1 deletion src/schedule/schedule_lang.cc
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,6 @@ Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) { // NOLINT
IterVar fused = IterVarNode::make(
Range(), Var(fused_name, outer->var.type()), iter_type);

*p_target = fused;
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();

Expand All @@ -255,6 +254,31 @@ Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) { // NOLINT
leaf_vars->data.begin() + pos_inner + 1);
leaf_vars->data.insert(leaf_vars->data.begin() + pos_outer,
fused.node_);
*p_target = fused;
return *this;
}

Stage& Stage::fuse(const Array<IterVar>& axes, IterVar* p_target) { // NOLINT(*)
if (axes.size() != 0) {
IterVar fused = axes[0];
for (size_t i = 1; i < axes.size(); ++i) {
this->fuse(fused, axes[i], &fused);
}
*p_target = std::move(fused);
} else {
StageNode* self = operator->();
// special handle fuse empty array.
// insert at the outer most loop
IterVar singleton = IterVarNode::make(
Range::make_by_min_extent(0, 1),
Var("singleton", Int(32)), kDataPar);
self->relations.push_back(SingletonNode::make(singleton));
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
all_vars->data.push_back(singleton.node_);
leaf_vars->data.insert(leaf_vars->data.begin(), singleton.node_);
*p_target = singleton;
}
return *this;
}

Expand Down Expand Up @@ -732,11 +756,18 @@ IterVarRelation RebaseNode::make(IterVar parent, IterVar rebased) {
return IterVarRelation(n);
}

IterVarRelation SingletonNode::make(IterVar iter) {
auto n = std::make_shared<SingletonNode>();
n->iter = iter;
return IterVarRelation(n);
}

TVM_REGISTER_NODE_TYPE(StageNode);
TVM_REGISTER_NODE_TYPE(IterVarAttrNode);
TVM_REGISTER_NODE_TYPE(SplitNode);
TVM_REGISTER_NODE_TYPE(FuseNode);
TVM_REGISTER_NODE_TYPE(RebaseNode);
TVM_REGISTER_NODE_TYPE(SingletonNode);
TVM_REGISTER_NODE_TYPE(ScheduleNode);

// Printer
Expand Down Expand Up @@ -778,6 +809,11 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->print(op->rebased);
p->stream << ')';
})
.set_dispatch<SingletonNode>([](const SingletonNode *op, IRPrinter *p) {
p->stream << "singleton(";
p->print(op->iter);
p->stream << ')';
})
.set_dispatch<ScheduleNode>([](const ScheduleNode *op, IRPrinter *p) {
p->stream << "schedule(" << op << ")";
});
Expand Down
9 changes: 4 additions & 5 deletions tests/python/integration/test_ewise.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ def test_multiple_cache_write():
n = tvm.convert(1024)
A0 = tvm.placeholder((n,), name='A0', dtype = "float32")
A1 = tvm.placeholder((n,), name='A1', dtype = "float32")
B0, B1 = tvm.compute((n,),
lambda *i: (A0(*i) + A1(*i), A0(*i) * A1(*i)),
B0, B1 = tvm.compute((n,),
lambda *i: (A0(*i) + A1(*i), A0(*i) * A1(*i)),
name='B')
C = tvm.compute((n,), lambda *i: B0(*i) + B1(*i),
C = tvm.compute((n,), lambda *i: B0(*i) + B1(*i),
name='C')
s = tvm.create_schedule(C.op)
# create iter var and assign them tags.
Expand Down Expand Up @@ -76,7 +76,7 @@ def check_device(device, host="stackvm"):
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
func(a0, a1, c)
np.testing.assert_allclose(
c.asnumpy(), a0.asnumpy() + a1.asnumpy() + (a0.asnumpy() * a1.asnumpy()),
c.asnumpy(), a0.asnumpy() + a1.asnumpy() + (a0.asnumpy() * a1.asnumpy()),
rtol=1e-5)

check_device("cuda", "llvm")
Expand Down Expand Up @@ -235,7 +235,6 @@ def check_device(device):
f(a, b)
np.testing.assert_allclose(
b.asnumpy(), a.asnumpy() + 3, rtol=1e-6)

check_device("cuda")


Expand Down
14 changes: 14 additions & 0 deletions tests/python/unittest/test_lang_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,19 @@ def test_fuse():
assert any(isinstance(x, tvm.schedule.Fuse) for x in s[T].relations)
assert tuple(s[T].leaf_iter_vars) == (fused, xi, yi)


def test_singleton():
A = tvm.placeholder((), name='A')
T = tvm.compute((), lambda : A() + 1)
s = tvm.create_schedule(T.op)
fused = s[T].fuse()
assert any(isinstance(x, tvm.schedule.Singleton) for x in s[T].relations)
assert tuple(s[T].leaf_iter_vars) == (fused,)
dump = pkl.dumps(s)
s_loaded = pkl.loads(dump)
assert isinstance(s_loaded, tvm.schedule.Schedule)


def test_vectorize():
m = tvm.var('m')
n = tvm.var('n')
Expand Down Expand Up @@ -174,6 +187,7 @@ def intrin_func(ins, outs):


if __name__ == "__main__":
test_singleton()
test_pragma()
test_tensor_intrin()
test_rfactor()
Expand Down
6 changes: 5 additions & 1 deletion topi/tests/python/test_topi_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def test_broadcast_to():
verify_broadcast_to_ele((1, 128, 1, 32), (64, 128, 64, 32), topi.broadcast_to)

def test_add():
verify_broadcast_binary_ele(
(), (), topi.add, np.add)
verify_broadcast_binary_ele(
(5, 2, 3), (2, 1), topi.add, np.add)

Expand All @@ -113,6 +115,8 @@ def test_multiply():
def test_divide():
verify_broadcast_binary_ele(
None, (10,), topi.divide, np.divide, rhs_min=0.0001)
verify_broadcast_binary_ele(
(), None, topi.divide, np.divide, rhs_min=0.0001)
verify_broadcast_binary_ele(
(2, 3, 1, 32), (64, 32), topi.divide, np.divide, rhs_min=0.0001)

Expand Down Expand Up @@ -157,10 +161,10 @@ def test_shift():


if __name__ == "__main__":
test_add()
test_shift()
test_cmp()
test_mod()
test_add()
test_subtract()
test_multiply()
test_divide()
Expand Down

0 comments on commit 8464d12

Please sign in to comment.