Skip to content

Commit

Permalink
[Arith] Use ConstIntBound to remove negative numerator when lowering (#…
Browse files Browse the repository at this point in the history
…13724)

* [Arith] Use ConstIntBound to remove negative numerator when lowering

Negative numerators to modulo/remainder operations are not supported
by the Vulkan API.  While the SPIR-V instructions
[`OpSRem`](https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpSRem)
and
[`OpSMod`](https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpSMod)
have identical semantics to `tir::Mod` and `tir::FloorMod`,
respectively, use of either instruction within Vulkan results in
undefined behavior.  From the [Vulkan
spec](https://registry.khronos.org/vulkan/specs/1.3/html/chap37.html#spirvenv-op-prec):

> For the OpSRem and OpSMod instructions, if either operand is
> negative the result is undefined.
>
> Note: While the OpSRem and OpSMod instructions are supported by the
> Vulkan environment, they require non-negative values and thus do not
> enable additional functionality beyond what OpUMod provides.

This issue was first noticed in
#13530, where use of integer
arithmetic resulted in negative numerators.  This hadn't caused issues
previously, because most use of div/mod use a denominator that is a
power of two.  In these cases, `tir.LowerIntrin` implements floordiv
and floormod using only bitwise operations.  When the denominator
isn't a power of two, both `tir::FloorDiv` and `tir::FloorMod` are
implemented in terms of `tir::Mod`, which triggers the undefined
behavior for negative numerators.

This commit alters the lowering of FloorDiv/FloorMod to
TruncDiv/TruncMod, in cases where the denominator is positive, the
numerator is sometimes negative, and the range of the numerator is
known.  In these cases, the FloorDiv/FloorMod is now implemented by
offsetting the numerator such that it is always positive.

* Add check to avoid -INT32_MIN

* Updated to use `tvm::min_value(DataType)`

* Added derivation for floordiv/floormod in terms of truncdiv/trundmod
  • Loading branch information
Lunderberg authored Jan 10, 2023
1 parent 4cb75b9 commit 68c917d
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 25 deletions.
136 changes: 111 additions & 25 deletions src/tir/transforms/lower_intrin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/tir/op.h>
#include <tvm/tir/transform.h>

#include <limits>
#include <unordered_set>

#include "../../arith/ir_mutator_with_analyzer.h"
Expand Down Expand Up @@ -112,20 +113,63 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
// Common path, positive divisor
if (analyzer_->CanProveGreaterEqual(op->a, 0) || analyzer_->CanProveGreaterEqual(e, 0)) {
return truncdiv(op->a, op->b);
}

// If the numerator's lower bound is known, express the floordiv
// in terms of truncdiv using only positive operands.
arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a);
if (const_int_bound->min_value != arith::ConstIntBound::kNegInf &&
const_int_bound->min_value < 0 &&
const_int_bound->min_value > Downcast<IntImm>(tvm::min_value(op->a->dtype))->value) {
// The goal is to write floordiv(a,b) in terms of truncdiv, without using
// negative operands.
//
// For any integer c
//
// floordiv(a,b) == floordiv(a + b*c - b*c, b)
// == floordiv(a + b*c, b) - c
//
// Choosing `c = ceildiv(-a_min, b)`. This can be rewritten in terms of
// truncdiv as follows.
//
// c == ceildiv(-a_min,b)
// == floordiv(-a_min + (b-1), b)
// == truncdiv(-a_min + (b-1), b)
//
// When substituted into `a + b*c`, this results in a positive argument.
//
// a + b*c
// == a + b*ceildiv(-a_min,b)
// == a - b*floordiv(a_min,b)
// >= a - b*floordiv(a,b)
// == floormod(a, b)
// >= 0
//
// Since the argument is positive, this allows floordiv to be written as
// followed.
//
// floordiv(a,b)
// == floordiv(a + b*c, b) - c
// == truncdiv(a + b*c, b) - c
IntImm min(op->a->dtype, const_int_bound->min_value);
PrimExpr ceildiv = truncdiv((op->b - 1) - min, op->b);
PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * ceildiv);
return truncdiv(offset_numerator, op->b) - ceildiv;
}

DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divident";
PrimExpr rdiv = truncdiv(op->a, op->b);
PrimExpr rmod = truncmod(op->a, op->b);
// condition on b >= 0.
// truncmod(a, b) < 0 will implies ceildiv,
// So we need to correct these cases.
if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) && support_bitwise_op_) {
// equivalent to rdiv + (rmod >= 0 ? 0: -1);
return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1));
} else {
DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divident";
PrimExpr rdiv = truncdiv(op->a, op->b);
PrimExpr rmod = truncmod(op->a, op->b);
// condition on b >= 0.
// truncmod(a, b) < 0 will implies ceildiv,
// So we need to correct these cases.
if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) && support_bitwise_op_) {
// equivalent to rdiv + (rmod >= 0 ? 0: -1);
return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1));
} else {
return tir::Select(rmod >= 0, rdiv, rdiv - make_const(dtype, 1));
}
return tir::Select(rmod >= 0, rdiv, rdiv - make_const(dtype, 1));
}

} else {
if (dtype.is_float()) {
// floor(a / b)
Expand Down Expand Up @@ -165,21 +209,63 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
// Common pass, positive divisor
if (analyzer_->CanProveGreaterEqual(op->a, 0)) {
return truncmod(op->a, op->b);
}

// If the numerator's lower bound is known, express the floormod
// in terms of truncmod using only positive operands.
arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a);
if (const_int_bound->min_value != arith::ConstIntBound::kNegInf &&
const_int_bound->min_value < 0 &&
const_int_bound->min_value > Downcast<IntImm>(tvm::min_value(op->a->dtype))->value) {
// The goal is to write floormod(a,b) in terms of truncdiv and truncmod,
// without using negative operands.
//
// For any integer c
//
// floormod(a, b) == floormod(a + b*c, b)
//
// Choosing `c = ceildiv(-a_min, b)`. This can be rewritten in terms of
// truncdiv as follows.
//
// c == ceildiv(-a_min,b)
// == floordiv(-a_min + (b-1), b)
// == truncdiv(-a_min + (b-1), b)
//
// When substituted into `a + b*c`, this results in a positive argument.
//
// a + b*c
// == a + b*ceildiv(-a_min,b)
// == a - b*floordiv(a_min,b)
// >= a - b*floordiv(a,b)
// == floormod(a, b)
// >= 0
//
// Since the argument is positive, this allows floordiv to be written as
// followed.
//
// floormod(a,b)
// == floormod(a + b*c, b)
// == truncmod(a + b*c, b)
IntImm min(op->a->dtype, const_int_bound->min_value);
PrimExpr ceildiv = truncdiv(-min + (op->b - 1), op->b);
PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * ceildiv);
return truncmod(offset_numerator, op->b);
}

DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divident";
// NOTE:condition on b >= 0.
// mod(a, b) < 0 will imply we are doing ceildiv,
// So we need to correct these cases.
PrimExpr rmod = truncmod(op->a, op->b);
if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) && support_bitwise_op_) {
// (rmod >> shift) & b
// -> (rmod >= 0 ? 0: -1) & b
// -> rmod >= 0 ? 0 : b
return rmod + (op->b & (rmod >> make_const(dtype, dtype.bits() - 1)));
} else {
DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divident";
// NOTE:condition on b >= 0.
// mod(a, b) < 0 will imply we are doing ceildiv,
// So we need to correct these cases.
PrimExpr rmod = truncmod(op->a, op->b);
if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) && support_bitwise_op_) {
// (rmod >> shift) & b
// -> (rmod >= 0 ? 0: -1) & b
// -> rmod >= 0 ? 0 : b
return rmod + (op->b & (rmod >> make_const(dtype, dtype.bits() - 1)));
} else {
return tir::Select(rmod >= 0, rmod, rmod + op->b);
}
return tir::Select(rmod >= 0, rmod, rmod + op->b);
}

} else {
if (dtype.is_float()) {
// a - floor(a / b) * b
Expand Down
42 changes: 42 additions & 0 deletions tests/python/unittest/test_target_codegen_vulkan.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import tvm.testing
from tvm import relay, te
from tvm.topi.math import cast
from tvm.script import tir as T


dtype = tvm.testing.parameter("float32", "int32", "float16", "int8")
Expand Down Expand Up @@ -558,5 +559,46 @@ def do_compute(ins, outs):
tvm.build(s, [Out], target)


def test_negative_operand_divmod(target, dev):
"""Test handling of negative offsets to floormod/floordiv
Even though the SPIR-V spec states that OpSRem and OpSMod can give
the signed modulo, the Vulkan spec states that any use of negative
operands is undefined behavior. This test starts with negative
operands to floordiv, validating that they are simplified into the
corresponding positive operands, such that the final TIR can be
expressed using only positive operands.
SPIR-V: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpSRem
Vulkan: https://registry.khronos.org/vulkan/specs/1.3/html/chap37.html#spirvenv-op-prec
"""

N = 32
offset = 16
divisor = 5

@T.prim_func
def func(A: T.Buffer[(N, 2), "int32"]):
for i in T.serial(N):
with T.block("A"):
v_i = T.axis.spatial(N, i)
A[v_i, 0] = T.floordiv(v_i - offset, divisor)
A[v_i, 1] = T.floormod(v_i - offset, divisor)

if "gpu" in tvm.target.Target(target).keys:
sch = tvm.tir.Schedule(func)
sch.bind(sch.get_loops("A")[0], "threadIdx.x")
func = sch.mod["main"]

built = tvm.build(func, target=target)

a_dev = tvm.nd.empty([N, 2], "int32", dev)
built(a_dev)
a = a_dev.numpy()

np.testing.assert_array_equal(a[:, 0], (np.arange(N) - offset) // divisor)
np.testing.assert_array_equal(a[:, 1], (np.arange(N) - offset) % divisor)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 68c917d

Please sign in to comment.