[CodeGen][CUDA] Enhance CUDA codegen for SelectNode
- This patch allows the CUDA backend to emit correct code for
  selects with vector conditions, which may be produced by
  floordiv op lowering, etc.

- This already works for the LLVM backend, as the LLVM select
  instruction supports vector conditions.

Signed-off-by: Wei Pan <[email protected]>
wpan11nv committed Mar 4, 2020
1 parent 5a0f39b commit 290c58a
Showing 4 changed files with 86 additions and 4 deletions.
4 changes: 4 additions & 0 deletions include/tvm/runtime/data_type.h
@@ -112,6 +112,10 @@ class DataType {
  bool is_vector() const {
    return lanes() > 1;
  }
  /*! \return whether type is a bool vector type. */
  bool is_vector_bool() const {
    return is_vector() && bits() == 1;
  }
  /*!
   * \brief Create a new data type by changing lanes to a specified value.
   * \param lanes The target number of lanes.
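As a quick illustration (not part of the diff), a minimal sketch of the new predicate, assuming TVM's C++ API where DataType::Bool(lanes) builds a 1-bit unsigned type with the given lane count:

    tvm::runtime::DataType b4 = tvm::runtime::DataType::Bool(4);
    CHECK(b4.is_vector());       // lanes() > 1
    CHECK(b4.is_vector_bool());  // vector and bits() == 1
    CHECK(!tvm::runtime::DataType::Bool().is_vector_bool());  // scalar bool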
51 changes: 50 additions & 1 deletion src/target/source/codegen_cuda.cc
Expand Up @@ -135,6 +135,13 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
    }
  } else if (t == DataType::Bool()) {
    os << "bool"; return;
  } else if (t.is_vector_bool()) {
    // CUDA does not support bool vectors.
    // Use ushort vectors to represent them instead.
    int n = t.lanes();
    if (n <= 4) {
      os << "ushort" << n; return;
    }
  } else if (t.is_uint() || t.is_int()) {
    if (t.is_uint()) {
      if (t.lanes() != 1) {
@@ -226,7 +233,7 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
}
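As an aside (not part of the diff), a minimal sketch of what the new branch prints, assuming a CodeGenCUDA instance named cg (hypothetical):

    std::ostringstream os;
    cg.PrintType(tvm::runtime::DataType::Bool(4), os);  // writes "ushort4"
    // Bool vectors wider than 4 lanes fall through to the
    // unsupported-type error at the end of PrintType.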

void CodeGenCUDA::PrintVecBinaryOp(
-    const std::string&op, DataType t,
+    const std::string& op, DataType t,
    PrimExpr lhs, PrimExpr rhs, std::ostream& os) { // NOLINT(*)
  // unpacking operations.
  int lanes = t.lanes();
@@ -561,6 +568,48 @@ void CodeGenCUDA::VisitExpr_(const ShuffleNode* op, std::ostream &os) {
  os << ')';
}

void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream &os) {
  // Non-vector cases.
  if (!op->dtype.is_vector()) {
    CodeGenC::VisitExpr_(op, os);
    return;
  }

  // Codegen the vector condition case by serializing the select op.
  CHECK(op->false_value->dtype == op->dtype &&
        op->true_value->dtype == op->dtype &&
        op->dtype.lanes() == op->condition.dtype().lanes());

  int lanes = op->dtype.lanes();
  int scope = BeginScope();

  std::string c_var = SSAGetID(PrintExpr(op->condition), op->dtype);
  std::string t_var = SSAGetID(PrintExpr(op->true_value), op->dtype);
  std::string f_var = SSAGetID(PrintExpr(op->false_value), op->dtype);
  std::string r_var = GetUniqueName("_");

  this->PrintIndent();
  this->PrintType(op->dtype, stream);
  stream << ' ' << r_var << ";\n";

  // The condition is stored as a ushort vector.
  DataType memory_ty(DataType::TypeCode::kUInt, 16, lanes);

  for (int i = 0; i < lanes; ++i) {
    std::ostringstream item;
    item << "(bool(";
    PrintVecElemLoad(c_var, memory_ty, i, item);
    item << ")?";
    PrintVecElemLoad(t_var, op->dtype, i, item);
    item << ':';
    PrintVecElemLoad(f_var, op->dtype, i, item);
    item << ')';
    PrintVecElemStore(r_var, op->dtype, i, item.str());
  }
  os << r_var;
  EndScope(scope);
}
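For a 4-lane float select, the serialized output takes roughly the following shape (an illustrative sketch; _1 through _4 are hypothetical stand-ins for the SSA ids the codegen actually picks, with _1 holding the condition as one ushort per bool lane and _2/_3 the true/false values):

    float4 _4;
    _4.x = (bool(_1.x)?_2.x:_3.x);
    _4.y = (bool(_1.y)?_2.y:_3.y);
    _4.z = (bool(_1.z)?_2.z:_3.z);
    _4.w = (bool(_1.w)?_2.w:_3.w);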

inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*)
  switch (op->dtype.bits()) {
    case 64: case 32: {
5 changes: 3 additions & 2 deletions src/target/source/codegen_cuda.h
@@ -43,11 +43,11 @@ class CodeGenCUDA final : public CodeGenC {
    return (enable_fp16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_);
  }
  // override behavior
-  void VisitStmt_(const tir::ForNode* op) final;
+  void VisitStmt_(const ForNode* op) final;
  void PrintStorageSync(const CallNode* op) final;
  void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
  void PrintVecBinaryOp(
-      const std::string&op, DataType t,
+      const std::string& op, DataType t,
      PrimExpr lhs, PrimExpr rhs, std::ostream& os) final; // NOLINT(*)
  void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
  void PrintVecElemLoad(
@@ -58,6 +58,7 @@ class CodeGenCUDA final : public CodeGenC {
  // overload visitor
  void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*)
  void VisitExpr_(const ShuffleNode* op, std::ostream& os) final; // NOLINT(*)
  void VisitExpr_(const SelectNode* op, std::ostream& os) final; // NOLINT(*)
  void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
  void VisitExpr_(const FloatImmNode *op, std::ostream& os) final;
  void VisitExpr_(const CallNode *op, std::ostream& os) final;
30 changes: 29 additions & 1 deletion tests/python/unittest/test_codegen_cuda.py
@@ -321,6 +321,33 @@ def check_cuda(dtype, m=32, n=32):
check_cuda("float32")
check_cuda("float16")

def test_cuda_floordiv_with_vectorization():
    if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
        print("skip because cuda is not enabled.")
        return

    with tvm.target.cuda():
        # B[i] = A[floordiv(i, k)]
        n = 256
        k = 37
        A = te.placeholder((n,), name='A')
        B = te.compute((n,), lambda i: A[tvm.tir.floordiv(i, k)], name='B')
        s = te.create_schedule(B.op)
        xo, xi = s[B].split(B.op.axis[0], nparts=1)
        xio, xii = s[B].split(xi, factor=4)
        s[B].vectorize(xii)
        # bx and tx are the module-level thread axes used throughout this
        # test file: te.thread_axis("blockIdx.x") / te.thread_axis("threadIdx.x").
        s[B].bind(xo, bx)
        s[B].bind(xio, tx)
        func = tvm.build(s, [A, B], 'cuda')

        ctx = tvm.gpu(0)
        a_np = np.random.uniform(size=(n,)).astype(A.dtype)
        b_np = np.array([a_np[i//k] for i in range(0, n)])
        a_nd = tvm.nd.array(a_np, ctx)
        b_nd = tvm.nd.array(np.zeros(b_np.shape, dtype=b_np.dtype), ctx)
        func(a_nd, b_nd)
        tvm.testing.assert_allclose(b_nd.asnumpy(), b_np, rtol=1e-3)
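To inspect the serialized select in the generated kernel while debugging, the CUDA source of the built module can be dumped (a hypothetical aid, not part of the committed test):

    # The device code lives in the first imported module of the host module.
    dev_module = func.imported_modules[0]
    print(dev_module.get_source())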

if __name__ == "__main__":
    test_cuda_vectorize_add()
    test_cuda_multiply_add()
@@ -331,4 +358,5 @@ def check_cuda(dtype, m=32, n=32):
    test_cuda_reducition_binding()
    test_rfactor_predicates()
    test_cuda_const_float_to_half()
-    test_cuda_reduction()
+    test_cuda_reduction()
+    test_cuda_floordiv_with_vectorization()
