Skip to content

Commit

Permalink
Improved uncommon case of floormod and floordiv. Removed dependence on np floor_div and fmod.
Browse files Browse the repository at this point in the history
  • Loading branch information
dpankratz committed May 24, 2020
1 parent 0833b07 commit 67567ff
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 35 deletions.
68 changes: 51 additions & 17 deletions src/tir/transforms/lower_intrin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,30 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
}
}
} else {
// uncommon case
DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divisor";
// b >= 0 => (rmod >=0 ? rdiv : rdiv - 1)
// b < 0 => (rmod <= 0 ? rdiv : rdiv - 1)
PrimExpr rdiv = truncdiv(op->a, op->b);
PrimExpr rmod = truncmod(op->a, op->b);
return tir::SelectNode::make((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rdiv,
rdiv - make_const(dtype, 1));
if (dtype.bits() <= 32) {
/* NOTE:
This must be restricted to int32 or less since floats can losslessly represent integers
only if the number of bits in the mantissa exceeds the number of bits in the integer.
Therefore a double (53 bit mantissa) for int32, float (24 bit mantissa) for int16, etc.
Since TVM is unaware of a float128 type, int64 is not supported.
*/

// floor(a / b)
auto fdtype = DataType::Float(dtype.bits() * 2, dtype.lanes());
auto div = tir::CastNode::make(fdtype, op->a)
/ tir::CastNode::make(fdtype, op->b);
auto f = tvm::floor(div);
return tir::CastNode::make(dtype, VisitExpr_(f.as<CallNode>()));
} else {
// uncommon case
DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divisor";
// b >= 0 => (rmod >=0 ? rdiv : rdiv - 1)
// b < 0 => (rmod <= 0 ? rdiv : rdiv - 1)
PrimExpr rdiv = truncdiv(op->a, op->b);
PrimExpr rmod = truncmod(op->a, op->b);
return tir::SelectNode::make((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rdiv,
rdiv - make_const(dtype, 1));
}
}
}

Expand Down Expand Up @@ -148,15 +164,33 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
}
}
} else {
// uncommon case
DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divsor and divident";
PrimExpr rmod = truncmod(op->a, op->b);
// b > 0 && rmod >= 0 -> rmod
// b > 0 && rmod < 0 -> rmod + b
// b < 0 && rmod < 0 -> rmod
// b < 0 && rmod > 0 -> rmod + b
return tir::SelectNode::make((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rmod,
rmod + op->b);
if (dtype.bits() <= 32) {
/* NOTE:
This must be restricted to int32 or less since floats can losslessly represent integers
only if the number of bits in the mantissa exceeds the number of bits in the integer.
Therefore a double (53 bit mantissa) for int32, float (24 bit mantissa) for int16, etc.
Since there is no float128 type, int64 is not supported.
*/

// a - floor(a / b) * b
auto fdtype = DataType::Float(dtype.bits() * 2, dtype.lanes());
auto div = tir::CastNode::make(fdtype, op->a)
/ tir::CastNode::make(fdtype, op->b);
auto f = tvm::floor(div);
auto floor_lowered = tir::CastNode::make(dtype, VisitExpr_(f.as<CallNode>()));

return op->a - (floor_lowered * op->b);
} else {
// uncommon case
DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divsor and divident";
PrimExpr rmod = truncmod(op->a, op->b);
// b > 0 && rmod >= 0 -> rmod
// b > 0 && rmod < 0 -> rmod + b
// b < 0 && rmod < 0 -> rmod
// b < 0 && rmod > 0 -> rmod + b
return tir::SelectNode::make((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rmod,
rmod + op->b);
}
}
}

Expand Down
27 changes: 9 additions & 18 deletions topi/tests/python/test_topi_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,19 +90,6 @@ def check_device(device):
rhs_npy, rhs_nd = gen_operand(rhs_shape, rhs_min, rhs_max, ctx)
out_npy = fnumpy(lhs_npy, rhs_npy)

if fnumpy == np.floor_divide:
# avoid check too close to X.5 and X.0
# FIXME: floor_divide(94.90735, 0.6731018) behaves as floor(div(94.90735, 0.6731018))
# However the result is somehow incorrect - need to further investigate.
# And looks like numpy's floor_div(a,b) is implemented different from floor(div(a,b))
mask = np.logical_or(np.abs(np.abs(np.fmod(lhs_npy / rhs_npy, 1)) - 0.5) < 1e-6,
np.abs(np.fmod(lhs_npy / rhs_npy, 1)) < 1e-6)
if mask.any():
lhs_npy = lhs_npy + mask * 1e-3 * rhs_npy
lhs_npy = lhs_npy.astype(dtype)
lhs_nd = tvm.nd.array(lhs_npy, ctx) if lhs_shape is not None else lhs_npy.item()
out_npy = fnumpy(lhs_npy, rhs_npy)

out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(C.dtype), ctx)
foo(lhs_nd, rhs_nd, out_nd)
tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4)
Expand Down Expand Up @@ -151,12 +138,14 @@ def test_divide():
(2, 3, 1, 32), (64, 32), topi.divide, np.divide, rhs_min=0.0001)

def test_floor_divide():
    # Reference implementation: floor(a / b). np.floor_divide is deliberately
    # avoided here because its result can disagree with floor(div(a, b)) for
    # quotients that land very close to X.0 / X.5 boundaries, which made the
    # comparison against TVM's lowering (cast-to-float, divide, floor) flaky.
    def _canonical_floor_div(a, b):
        return np.floor(a / b)

    # Scalar lhs broadcast against a vector rhs.
    verify_broadcast_binary_ele(
        None, (10,), topi.floor_divide, _canonical_floor_div, rhs_min=0.0001)
    # Vector lhs broadcast against a scalar rhs.
    verify_broadcast_binary_ele(
        (), None, topi.floor_divide, _canonical_floor_div, rhs_min=0.0001)
    # Full-rank broadcast; rhs_min keeps the divisor away from zero.
    verify_broadcast_binary_ele(
        (2, 3, 64, 32), (64, 32), topi.floor_divide, _canonical_floor_div, rhs_min=0.0001)

def test_maximum_minmum():
verify_broadcast_binary_ele(
Expand All @@ -175,10 +164,12 @@ def test_mod():
(1, 2, 2), (2,), topi.mod, np.mod, lhs_min=0.001, rhs_min=1, dtype="int32")

def test_floor_mod():
    # Reference implementation: a - floor(a / b) * b. np.fmod is deliberately
    # avoided: fmod truncates the quotient toward zero, which differs from
    # floor semantics whenever the operands have mixed signs, so it is not a
    # valid oracle for TVM's floor_mod.
    def _canonical_floor_mod(a, b):
        return a - np.floor(a / b) * b

    # Integer case; rhs_min=1 keeps the divisor away from zero.
    verify_broadcast_binary_ele(
        (1, 2, 2), (2,), topi.floor_mod, _canonical_floor_mod,
        lhs_min=0.001, rhs_min=1, dtype="int32")
    # Float case with identical shapes (no broadcasting).
    verify_broadcast_binary_ele(
        (3, 4, 5), (3, 4, 5), topi.floor_mod, _canonical_floor_mod,
        lhs_min=0.001, rhs_min=1, dtype="float32")

def test_cmp():
# explicit specify the output type
Expand Down

0 comments on commit 67567ff

Please sign in to comment.