From 67567ff8cdd171d31288220bb24dc53c9916edb2 Mon Sep 17 00:00:00 2001 From: pankratz Date: Sun, 24 May 2020 10:07:19 -0600 Subject: [PATCH] Improved uncommon case of floormod and floordiv. Removed dependence on np floor_div and fmod. --- src/tir/transforms/lower_intrin.cc | 68 ++++++++++++++++++------ topi/tests/python/test_topi_broadcast.py | 27 ++++------ 2 files changed, 60 insertions(+), 35 deletions(-) diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 7df8fd257ca58..71e90b993c247 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -102,14 +102,30 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { } } } else { - // uncommon case - DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divisor"; - // b >= 0 => (rmod >=0 ? rdiv : rdiv - 1) - // b < 0 => (rmod <= 0 ? rdiv : rdiv - 1) - PrimExpr rdiv = truncdiv(op->a, op->b); - PrimExpr rmod = truncmod(op->a, op->b); - return tir::SelectNode::make((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rdiv, - rdiv - make_const(dtype, 1)); + if (dtype.bits() <= 32) { + /* NOTE: + This must be restricted to int32 or less since floats can losslessly represent integers + only if the number of bits in the mantissa exceeds the number of bits in the integer. + Therefore a double (53 bit mantissa) for int32, float (24 bit mantissa) for int16, etc. + Since TVM is unaware of a float128 type, int64 is not supported. + */ + + // floor(a / b) + auto fdtype = DataType::Float(dtype.bits() * 2, dtype.lanes()); + auto div = tir::CastNode::make(fdtype, op->a) + / tir::CastNode::make(fdtype, op->b); + auto f = tvm::floor(div); + return tir::CastNode::make(dtype, VisitExpr_(f.as())); + } else { + // uncommon case + DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divisor"; + // b >= 0 => (rmod >=0 ? rdiv : rdiv - 1) + // b < 0 => (rmod <= 0 ? rdiv : rdiv - 1) + PrimExpr rdiv = truncdiv(op->a, op->b); + PrimExpr rmod = truncmod(op->a, op->b); + return tir::SelectNode::make((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rdiv, + rdiv - make_const(dtype, 1)); + } } } @@ -148,15 +164,33 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { } } } else { - // uncommon case - DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divsor and divident"; - PrimExpr rmod = truncmod(op->a, op->b); - // b > 0 && rmod >= 0 -> rmod - // b > 0 && rmod < 0 -> rmod + b - // b < 0 && rmod < 0 -> rmod - // b < 0 && rmod > 0 -> rmod + b - return tir::SelectNode::make((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rmod, - rmod + op->b); + if (dtype.bits() <= 32) { + /* NOTE: + This must be restricted to int32 or less since floats can losslessly represent integers + only if the number of bits in the mantissa exceeds the number of bits in the integer. + Therefore a double (53 bit mantissa) for int32, float (24 bit mantissa) for int16, etc. + Since there is no float128 type, int64 is not supported. + */ + + // a - floor(a / b) * b + auto fdtype = DataType::Float(dtype.bits() * 2, dtype.lanes()); + auto div = tir::CastNode::make(fdtype, op->a) + / tir::CastNode::make(fdtype, op->b); + auto f = tvm::floor(div); + auto floor_lowered = tir::CastNode::make(dtype, VisitExpr_(f.as())); + + return op->a - (floor_lowered * op->b); + } else { + // uncommon case + DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divsor and divident"; + PrimExpr rmod = truncmod(op->a, op->b); + // b > 0 && rmod >= 0 -> rmod + // b > 0 && rmod < 0 -> rmod + b + // b < 0 && rmod < 0 -> rmod + // b < 0 && rmod > 0 -> rmod + b + return tir::SelectNode::make((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rmod, + rmod + op->b); + } } } diff --git a/topi/tests/python/test_topi_broadcast.py b/topi/tests/python/test_topi_broadcast.py index 27b66e04e3947..f3e0300a2d817 100644 --- a/topi/tests/python/test_topi_broadcast.py +++ b/topi/tests/python/test_topi_broadcast.py @@ -90,19 +90,6 @@ def check_device(device): rhs_npy, rhs_nd = gen_operand(rhs_shape, rhs_min, rhs_max, ctx) out_npy = fnumpy(lhs_npy, rhs_npy) - if fnumpy == np.floor_divide: - # avoid check too close to X.5 and X.0 - # FIXME: floor_divide(94.90735, 0.6731018) behaves as floor(div(94.90735, 0.6731018)) - # However the result is somehow incorrect - need to further investigate. - # And looks like numpy's floor_div(a,b) is implemented different from floor(div(a,b)) - mask = np.logical_or(np.abs(np.abs(np.fmod(lhs_npy / rhs_npy, 1)) - 0.5) < 1e-6, - np.abs(np.fmod(lhs_npy / rhs_npy, 1)) < 1e-6) - if mask.any(): - lhs_npy = lhs_npy + mask * 1e-3 * rhs_npy - lhs_npy = lhs_npy.astype(dtype) - lhs_nd = tvm.nd.array(lhs_npy, ctx) if lhs_shape is not None else lhs_npy.item() - out_npy = fnumpy(lhs_npy, rhs_npy) - out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(C.dtype), ctx) foo(lhs_nd, rhs_nd, out_nd) tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4) @@ -151,12 +138,14 @@ def test_divide(): (2, 3, 1, 32), (64, 32), topi.divide, np.divide, rhs_min=0.0001) def test_floor_divide(): + def _canonical_floor_div(a,b): + return np.floor(a / b) verify_broadcast_binary_ele( - None, (10,), topi.floor_divide, np.floor_divide, rhs_min=0.0001) + None, (10,), topi.floor_divide, _canonical_floor_div, rhs_min=0.0001) verify_broadcast_binary_ele( - (), None, topi.floor_divide, np.floor_divide, rhs_min=0.0001) + (), None, topi.floor_divide, _canonical_floor_div, rhs_min=0.0001) verify_broadcast_binary_ele( - (2, 3, 64, 32), (64, 32), topi.floor_divide, np.floor_divide, rhs_min=0.0001) + (2, 3, 64, 32), (64, 32), topi.floor_divide, _canonical_floor_div, rhs_min=0.0001) def test_maximum_minmum(): verify_broadcast_binary_ele( @@ -175,10 +164,12 @@ def test_mod(): (1, 2, 2), (2,), topi.mod, np.mod, lhs_min=0.001, rhs_min=1, dtype="int32") def test_floor_mod(): + def _canonical_floor_mod(a,b): + return a - np.floor(a / b) * b verify_broadcast_binary_ele( - (1, 2, 2), (2,), topi.floor_mod, np.fmod, lhs_min=0.001, rhs_min=1, dtype="int32") + (1, 2, 2), (2,), topi.floor_mod, _canonical_floor_mod, lhs_min=0.001, rhs_min=1, dtype="int32") verify_broadcast_binary_ele( - (3, 4, 5), (3, 4, 5), topi.floor_mod, np.fmod, lhs_min=0.001, rhs_min=1, dtype="float32") + (3, 4, 5), (3, 4, 5), topi.floor_mod, _canonical_floor_mod, lhs_min=0.001, rhs_min=1, dtype="float32") def test_cmp(): # explicit specify the output type