From f8718a1d1b2b46a3749a0991bd4e0012d6c79e85 Mon Sep 17 00:00:00 2001 From: LiangLiu <1432249204@qq.com> Date: Tue, 14 Apr 2020 11:35:31 +0800 Subject: [PATCH] [CODEGEN][CUDA] Fix vector load (#5226) * Fix high-low bit bug in __pack_half2 * Fix vector load * Add unit8 support for PrintVecElemLoadExpr and BroadcastNode --- src/target/source/codegen_c.cc | 40 +++++++++----- src/target/source/codegen_c.h | 2 + src/target/source/codegen_cuda.cc | 52 ++++++++++++++++++- src/target/source/codegen_cuda.h | 1 + src/target/source/literal/cuda_half_t.h | 2 +- .../unittest/test_target_codegen_cuda.py | 39 ++++++++++++++ 6 files changed, 121 insertions(+), 15 deletions(-) diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 444dc996b10f..ac13f8a50091 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -668,15 +668,7 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*) std::string ref = GetVecLoad(op->dtype, op->buffer_var.get(), base); HandleVolatileLoads(ref, op, os); } else { - // The assignment below introduces side-effect, and the resulting value cannot - // be reused across multiple expression, thus a new scope is needed - int vec_scope = BeginScope(); - - // load seperately. - std::string svalue = GetUniqueName("_"); - this->PrintIndent(); - this->PrintType(op->dtype, stream); - stream << ' ' << svalue << ";\n"; + std::ostringstream svalue_expr; std::string sindex = SSAGetID(PrintExpr(op->index), op->index.dtype()); std::string vid = GetVarID(op->buffer_var.get()); DataType elem_type = op->dtype.element_of(); @@ -699,10 +691,9 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*) value_temp << '['; PrintVecElemLoad(sindex, op->index.dtype(), i, value_temp); value_temp << ']'; - PrintVecElemStore(svalue, op->dtype, i, value_temp.str()); + PrintVecElemLoadExpr(op->dtype, i, value_temp.str(), svalue_expr); } - os << svalue; - EndScope(vec_scope); + os << svalue_expr.str(); } } } @@ -955,5 +946,30 @@ void CodeGenC::VisitStmt_(const ProducerConsumerNode* op) { PrintStmt(op->body); } +void CodeGenC::PrintVecElemLoadExpr( + DataType t, int i, const std::string& value, std::ostream& os) { + CHECK_GT(t.lanes(), 1); + if (t.bits() == 8 && (t.is_int() || t.is_uint())) { + if (i != 0) { + os << "|"; + } + os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))"; + return; + } + + if (i == 0) { + os << "(("; + PrintType(t, os); + os << t.lanes() << ")("; + } + os << value; + if (i != t.lanes() - 1) { + os << ","; + } else { + os << "))"; + } + return; +} + } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 30ad890c923d..49139de2fd1c 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -191,6 +191,8 @@ class CodeGenC : const std::string& vec, DataType t, int i, const std::string& value); // Get a cast type from to virtual std::string CastFromTo(std::string value, DataType from, DataType target); + // Get load of single element with expression + virtual void PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os); protected: // Print reference to struct location diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 9c4fc69a9d78..c7971cef1bf6 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -591,13 +591,17 @@ void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) { } void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) - if (op->dtype.is_int() && op->dtype.bits() == 8 && op->lanes == 4) { + if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 && op->lanes == 4) { // make_int8x4 const int64_t *p = as_const_int(op->value); CHECK(p); int64_t v = *p & 0xFF; v = (v << 24) | (v << 16) | (v << 8) | v; - os << "(int)" << v; + if (op->dtype.is_uint()) { + os << "(uint)" << v; + } else { + os << "(int)" << v; + } return; } @@ -796,5 +800,49 @@ void CodeGenCUDA::HandleVolatileLoads(const std::string& value, } } +void CodeGenCUDA::PrintVecElemLoadExpr( + DataType t, int i, const std::string& value, std::ostream& os) { + CHECK_GT(t.lanes(), 1); + if (t.bits() == 8 && (t.is_int() || t.is_uint())) { + if (i != 0) { + os << "|"; + } + os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))"; + return; + } + + if (t.is_float16()) { + if (i == 0) { + os << "make_"; + PrintType(t, os); + os << '('; + } + if (i % 2 == 0) { + os << "__pack_half2(" << value; + } else { + os << "," << value << ")"; + if (i != t.lanes() - 1) { + os << ","; + } else { + os << ")"; + } + } + return; + } + + if (i == 0) { + os << "make_"; + PrintType(t, os); + os << "("; + } + os << value; + if (i != t.lanes() - 1) { + os << ","; + } else { + os << ")"; + } + return; +} + } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h index 6ba748755d5b..d1db7047b1b6 100644 --- a/src/target/source/codegen_cuda.h +++ b/src/target/source/codegen_cuda.h @@ -55,6 +55,7 @@ class CodeGenCUDA final : public CodeGenC { void PrintVecElemStore( const std::string& vec, DataType t, int i, const std::string& value) final; void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) + void PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os) final; // overload visitor void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const ShuffleNode* op, std::ostream& os) final; // NOLINT(*) diff --git a/src/target/source/literal/cuda_half_t.h b/src/target/source/literal/cuda_half_t.h index fd0652afb0d4..858ac8572a08 100644 --- a/src/target/source/literal/cuda_half_t.h +++ b/src/target/source/literal/cuda_half_t.h @@ -291,7 +291,7 @@ static inline __device__ __host__ unsigned __pack_half2(const half x, const half y) { unsigned v0 = *((unsigned short *)&x); unsigned v1 = *((unsigned short *)&y); - return (v0 << 16) | v1; + return (v1 << 16) | v0; } )"; diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py index bb162f41861d..739fc6fda76d 100644 --- a/tests/python/unittest/test_target_codegen_cuda.py +++ b/tests/python/unittest/test_target_codegen_cuda.py @@ -543,6 +543,44 @@ def run_test(dtype): run_test("uint32") run_test("uint64") +def test_cuda_vectorize_load_permute_pad(): + def check_cuda(dtype, n, l, padding, lanes): + if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): + print("skip because cuda is not enabled..") + return + if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version): + print("Skip because gpu does not have fp16 support") + return + + ctx = tvm.gpu(0) + A = tvm.te.placeholder((n, l), name='A', dtype=dtype) + B = tvm.te.compute((n // lanes, l + 2 * padding, lanes), + lambda i, j, k: tvm.te.if_then_else( + tvm.te.any(j < padding, j >= l + padding), + tvm.runtime.convert(0).astype(dtype), A[i * lanes + k, j - padding]), + name='B') + s = te.create_schedule(B.op) + block, thread, vectorize = s[B].op.axis + s[B].bind(block, bx) + s[B].bind(thread, tx) + s[B].vectorize(vectorize) + fun = tvm.build(s, [A, B], "cuda", name="vector_load_permute_pad") + np_a = np.random.randint( + low=-128, high=127, size=(n, l)).astype(A.dtype) + a = tvm.nd.empty((n, l), A.dtype, ctx).copyfrom(np_a) + b = tvm.nd.empty((n // lanes, l + padding * 2, lanes), B.dtype, ctx) + fun(a, b) + np_a_reshape = np_a.reshape(n // lanes, lanes, l).transpose(0, 2, 1) + ref = np.pad(np_a_reshape, ((0, 0), (padding, padding), + (0, 0)), mode='constant', constant_values=0) + tvm.testing.assert_allclose(b.asnumpy(), ref) + + check_cuda("int8", 64, 16, 3, 4) + check_cuda("uint8", 64, 16, 3, 4) + check_cuda("int32", 64, 16, 3, 4) + check_cuda("float16", 64, 16, 3, 4) + check_cuda("float32", 64, 16, 3, 4) + if __name__ == "__main__": test_cuda_vectorize_add() test_cuda_multiply_add() @@ -560,3 +598,4 @@ def run_test(dtype): test_vectorized_intrin1() test_vectorized_intrin2() test_vectorized_popcount() + test_cuda_vectorize_load_permute_pad()