Skip to content

Commit

Permalink
[ARITH] migrate indexdiv/mod to floordiv/mod (apache#4008)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored and wweic committed Sep 30, 2019
1 parent 3efe48e commit 0148904
Show file tree
Hide file tree
Showing 9 changed files with 38 additions and 19 deletions.
9 changes: 3 additions & 6 deletions python/tvm/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,16 +92,13 @@ def __rtruediv__(self, other):
return _generic.divide(other, self)

def __floordiv__(self, other):
# return _generic.floordiv(self, other)
return _generic.divide(self, other)
return _generic.floordiv(self, other)

def __rfloordiv__(self, other):
# return _generic.floordiv(other, self)
return _generic.divide(other, self)
return _generic.floordiv(other, self)

def __mod__(self, other):
raise div_ambiguity_error()
# return _make._OpMod(self, other)
return _make._OpFloorMod(self, other)

def __neg__(self):
neg_one = _api_internal._const(-1, self.dtype)
Expand Down
13 changes: 11 additions & 2 deletions src/lang/attr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand Down Expand Up @@ -87,6 +87,8 @@ class AttrFunctor<R(const NodeRef& n, Args...)> {
virtual R VisitAttr_(const ir::Mul* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Div* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Mod* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::FloorDiv* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::FloorMod* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Min* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Max* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::GE* op, Args... args) ATTR_FUNCTOR_DEFAULT;
Expand Down Expand Up @@ -119,6 +121,9 @@ class AttrFunctor<R(const NodeRef& n, Args...)> {
ATTR_FUNCTOR_DISPATCH(Sub);
ATTR_FUNCTOR_DISPATCH(Mul);
ATTR_FUNCTOR_DISPATCH(Div);
ATTR_FUNCTOR_DISPATCH(Mod);
ATTR_FUNCTOR_DISPATCH(FloorDiv);
ATTR_FUNCTOR_DISPATCH(FloorMod);
ATTR_FUNCTOR_DISPATCH(Min);
ATTR_FUNCTOR_DISPATCH(Max);
ATTR_FUNCTOR_DISPATCH(GE);
Expand Down Expand Up @@ -160,6 +165,8 @@ class AttrsEqualHandler :
bool VisitAttr_(const ir::Mul* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Div* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Mod* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::FloorDiv* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::FloorMod* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Min* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Max* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::GE* lhs, const NodeRef& other) final;
Expand Down Expand Up @@ -201,6 +208,8 @@ class AttrsHashHandler :
size_t VisitAttr_(const ir::Mul* op) final;
size_t VisitAttr_(const ir::Div* op) final;
size_t VisitAttr_(const ir::Mod* op) final;
size_t VisitAttr_(const ir::FloorDiv* op) final;
size_t VisitAttr_(const ir::FloorMod* op) final;
size_t VisitAttr_(const ir::Min* op) final;
size_t VisitAttr_(const ir::Max* op) final;
size_t VisitAttr_(const ir::GE* op) final;
Expand Down
8 changes: 6 additions & 2 deletions src/lang/attrs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand Down Expand Up @@ -154,6 +154,8 @@ TVM_DEFINE_ATTRS_BINOP_EQUAL(Sub);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Mul);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Div);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Mod);
TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorDiv);
TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorMod);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Max);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Min);
TVM_DEFINE_ATTRS_BINOP_EQUAL(GE);
Expand Down Expand Up @@ -266,6 +268,8 @@ TVM_DEFINE_ATTRS_BINOP_HASH(Sub);
TVM_DEFINE_ATTRS_BINOP_HASH(Mul);
TVM_DEFINE_ATTRS_BINOP_HASH(Div);
TVM_DEFINE_ATTRS_BINOP_HASH(Mod);
TVM_DEFINE_ATTRS_BINOP_HASH(FloorDiv);
TVM_DEFINE_ATTRS_BINOP_HASH(FloorMod);
TVM_DEFINE_ATTRS_BINOP_HASH(Max);
TVM_DEFINE_ATTRS_BINOP_HASH(Min);
TVM_DEFINE_ATTRS_BINOP_HASH(GE);
Expand Down
4 changes: 2 additions & 2 deletions src/lang/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
namespace tvm {

// TODO(tqchen): change to floormod/div
using IndexMod = ir::Mod;
using IndexDiv = ir::Div;
using IndexMod = ir::FloorMod;
using IndexDiv = ir::FloorDiv;

Array<Expr> SimplifyArray(Array<Expr> array) {
for (size_t i = 0; i < array.size(); ++i) {
Expand Down
4 changes: 2 additions & 2 deletions src/lang/expr_operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,11 @@ Expr operator%(Expr a, Expr b) {

// TODO(tqchen): switch to floordiv
Expr indexdiv(Expr a, Expr b) {
return truncdiv(a, b);
return floordiv(a, b);
}

Expr indexmod(Expr a, Expr b) {
return truncmod(a, b);
return floormod(a, b);
}

Expr floordiv(Expr a, Expr b) {
Expand Down
14 changes: 10 additions & 4 deletions src/pass/lower_intrin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
patterns_.push_back("tvm.intrin.rule." + starget + ".");
patterns_.push_back("tvm.intrin.rule.default.");
fma_ = runtime::Registry::Get(patterns_[0] + "fma");
if (target == "stackvm") {
support_bitwise_op_ = false;
}
}

Expr Mutate_(const Call* op, const Expr& e) final {
Expand Down Expand Up @@ -76,7 +79,8 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
const DataType& dtype = op->type;
CHECK(dtype.is_int() || !dtype.is_uint());

if (is_const_power_of_two_integer(op->b, &shift)) {
if (support_bitwise_op_ &&
is_const_power_of_two_integer(op->b, &shift)) {
// lower to right shift if possible.
return op->a >> make_const(dtype, shift);
}
Expand All @@ -93,7 +97,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
// condition on b >= 0.
// truncmod(a, b) < 0 will implies ceildiv,
// So we need to correct these cases.
if (dtype == Int(32) || dtype == Int(64)) {
if ((dtype == Int(32) || dtype == Int(64)) && support_bitwise_op_) {
// equivalent to rdiv + (rmod >= 0 ? 0: -1);
return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1));
} else {
Expand Down Expand Up @@ -122,7 +126,8 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
const DataType& dtype = op->type;
CHECK(dtype.is_int() || !dtype.is_uint());

if (is_const_power_of_two_integer(op->b, &shift)) {
if (support_bitwise_op_ &&
is_const_power_of_two_integer(op->b, &shift)) {
// lower to masking if possible.
int64_t mask = (
static_cast<int64_t>(1) << static_cast<int64_t>(shift)) - 1;
Expand All @@ -140,7 +145,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
// mod(a, b) < 0 will imply we are doing ceildiv,
// So we need to correct these cases.
Expr rmod = truncmod(op->a, op->b);
if (dtype == Int(32) || dtype == Int(64)) {
if ((dtype == Int(32) || dtype == Int(64)) && support_bitwise_op_) {
// (rmod >> shift) & b
// -> (rmod >= 0 ? 0: -1) & b
// -> rmod >= 0 ? 0 : b
Expand Down Expand Up @@ -268,6 +273,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
// patterns
std::vector<std::string> patterns_;
const PackedFunc* fma_{nullptr};
bool support_bitwise_op_{true};
};

Stmt LowerIntrinStmt(Stmt stmt, const std::string& target) {
Expand Down
2 changes: 2 additions & 0 deletions tests/python/unittest/test_codegen_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def test_add_pipeline():
stmt = tvm.ir_pass.Simplify(stmt)
fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Db], 0, True)
fsplits = [x for x in tvm.ir_pass.SplitHostDevice(fapi)]
# lower the floordiv(use stackvm rules so it works for all targets)
fsplits = [tvm.ir_pass.LowerIntrin(x, "stackvm") for x in fsplits]
fsplits[0] = tvm.ir_pass.LowerTVMBuiltin(fsplits[0])

def check_target(device, host="stackvm"):
Expand Down
1 change: 1 addition & 0 deletions tests/python/unittest/test_codegen_vm_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def tvm_call_back_get_shape(shape0):
stmt = tvm.make.Evaluate(tvm.call_packed("tvm_call_back_get_shape", Ab.shape[0]))
fapi = tvm.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 0, True)
fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
fapi = tvm.ir_pass.LowerIntrin(fapi, "stackvm")
run_jit(fapi, lambda f: f(a))


Expand Down
2 changes: 1 addition & 1 deletion topi/python/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def get_valid_counts_scan(data, partial_in, partial):
ib.scope_attr(bx, "thread_extent", nthread_bx)
var = tvm.make.node("FloatImm", dtype="float32", value=2)
new_range = num_anchors // elem_per_thread + 1
iteration = log(cast(new_range, "float32")) // math.log(2)
iteration = cast(log(cast(new_range, "float32")) / math.log(2), "int32")
# Scan: Kogge-Stone adder
with ib.if_scope(tvm.all(bx < batch_size, tx < tvm.min(new_range, num_anchors))):
with ib.for_range(0, iteration) as k:
Expand Down

0 comments on commit 0148904

Please sign in to comment.