From a279dd0e58913da45b0e19653b99bb9c21233e51 Mon Sep 17 00:00:00 2001
From: Jian Weng
Date: Thu, 1 Aug 2019 12:52:33 -0700
Subject: [PATCH] Add shuffle support to TVM (#3633)

---
 include/tvm/ir_functor_ext.h               |  1 +
 include/tvm/ir_visitor.h                   |  1 +
 src/codegen/build_common.h                 |  1 +
 src/codegen/codegen_c.cc                   |  4 ++
 src/codegen/codegen_c.h                    |  1 +
 src/codegen/codegen_cuda.cc                | 23 +++++++++--
 src/codegen/codegen_cuda.h                 |  1 +
 src/codegen/llvm/codegen_llvm.cc           | 29 +++++++++++++-
 src/codegen/llvm/codegen_llvm.h            |  1 +
 src/pass/ir_visitor.cc                     |  8 ++++
 tests/python/unittest/test_codegen_cuda.py | 44 ++++++++++++++++++++++
 tests/python/unittest/test_codegen_llvm.py | 32 ++++++++++++++++
 12 files changed, 141 insertions(+), 5 deletions(-)

diff --git a/include/tvm/ir_functor_ext.h b/include/tvm/ir_functor_ext.h
index d5b27a6a9c63..a7d91eacf851 100644
--- a/include/tvm/ir_functor_ext.h
+++ b/include/tvm/ir_functor_ext.h
@@ -199,6 +199,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
     IR_EXPR_FUNCTOR_DISPATCH(Not);
     IR_EXPR_FUNCTOR_DISPATCH(Select);
     IR_EXPR_FUNCTOR_DISPATCH(Ramp);
+    IR_EXPR_FUNCTOR_DISPATCH(Shuffle);
     IR_EXPR_FUNCTOR_DISPATCH(Broadcast);
     IR_EXPR_FUNCTOR_DISPATCH(IntImm);
     IR_EXPR_FUNCTOR_DISPATCH(UIntImm);
diff --git a/include/tvm/ir_visitor.h b/include/tvm/ir_visitor.h
index c36f631d9838..f20b91368587 100644
--- a/include/tvm/ir_visitor.h
+++ b/include/tvm/ir_visitor.h
@@ -131,6 +131,7 @@ class TVM_DLL IRVisitor {
   virtual void Visit_(const Not* op);
   virtual void Visit_(const Select* op);
   virtual void Visit_(const Ramp* op);
+  virtual void Visit_(const Shuffle* op);
   virtual void Visit_(const Broadcast* op);
   virtual void Visit_(const AssertStmt* op);
   virtual void Visit_(const ProducerConsumer* op);
diff --git a/src/codegen/build_common.h b/src/codegen/build_common.h
index 713922b1e326..0bb4002a2f18 100644
--- a/src/codegen/build_common.h
+++ b/src/codegen/build_common.h
@@ -26,6 +26,7 @@
 #define TVM_CODEGEN_BUILD_COMMON_H_
 
 #include <tvm/codegen.h>
+#include <tvm/expr_operator.h>
 #include <string>
 #include <unordered_map>
 #include "../runtime/meta_data.h"
diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc
index 395d3f3178c6..043c64702edc 100644
--- a/src/codegen/codegen_c.cc
+++ b/src/codegen/codegen_c.cc
@@ -728,6 +728,10 @@ void CodeGenC::VisitExpr_(const Ramp* op, std::ostream& os) {  // NOLINT(*)
   os << "))";
 }
 
+void CodeGenC::VisitExpr_(const Shuffle* op, std::ostream& os) {  // NOLINT(*)
+  LOG(FATAL) << "Shuffle: not supported ";
+}
+
 void CodeGenC::VisitExpr_(const Broadcast* op, std::ostream& os) {   // NOLINT(*)
   LOG(FATAL) << "Broadcast: not supported ";
 }
diff --git a/src/codegen/codegen_c.h b/src/codegen/codegen_c.h
index 5e84cd945bc5..5cd30f1c5074 100644
--- a/src/codegen/codegen_c.h
+++ b/src/codegen/codegen_c.h
@@ -126,6 +126,7 @@ class CodeGenC :
   void VisitExpr_(const Not* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const Select* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const Ramp* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const Shuffle* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const Broadcast* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const IntImm* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const UIntImm* op, std::ostream& os) override;  // NOLINT(*)
diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc
index a32473158bd5..d13b2c99c3bc 100644
--- a/src/codegen/codegen_cuda.cc
+++ b/src/codegen/codegen_cuda.cc
@@ -205,7 +205,7 @@ void CodeGenCUDA::PrintVecBinaryOp(
 
 void CodeGenCUDA::PrintVecElemLoad(
     const std::string& vec, Type t, int i, std::ostream& os) {  // NOLINT(*)
-  const char access[] = {'x', 'y', 'z', 'w'};
+  static const char access[] = {'x', 'y', 'z', 'w'};
   CHECK(i >= 0 && i < 4);
   os << vec << "." << access[i];
 }
@@ -213,7 +213,7 @@ void CodeGenCUDA::PrintVecElemLoad(
 void CodeGenCUDA::PrintVecElemStore(
     const std::string& vec, Type t, int i, const std::string& value) {
   this->PrintIndent();
-  const char access[] = {'x', 'y', 'z', 'w'};
+  static const char access[] = {'x', 'y', 'z', 'w'};
   CHECK(i >= 0 && i < 4);
   stream << vec << "." << access[i] << " = " << value << ";\n";
 }
@@ -308,7 +308,7 @@ void CodeGenCUDA::VisitExpr_(const Broadcast* op, std::ostream& os) {   // NOLINT(*)
   std::string v = PrintExpr(op->value);
   os << "make_";
   PrintType(op->type, os);
-  os << "(";
+  os << '(';
   for (int i = 0; i < op->lanes; ++i) {
     if (i != 0) os << ", ";
     os << v;
@@ -316,6 +316,23 @@ void CodeGenCUDA::VisitExpr_(const Broadcast* op, std::ostream& os) {   // NOLINT(*)
   os << ')';
 }
 
+void CodeGenCUDA::VisitExpr_(const Shuffle* op, std::ostream &os) {
+  std::vector<std::string> to_shuffle(op->vectors.size());
+  for (int i = 0, e = op->vectors.size(); i < e; ++i) {
+    CHECK(op->vectors[i].type().lanes() == 1) << "Only scalars can be shuffled in CUDA!";
+    to_shuffle[i] = PrintExpr(op->vectors[i]);
+  }
+  os << "make_";
+  PrintType(op->type, os);
+  os << '(';
+  for (int i = 0, e = op->indices.size(); i < e; ++i) {
+    const int64_t *val = as_const_int(op->indices[i]);
+    CHECK(val && *val >= 0 && (int) *val < (int) to_shuffle.size());
+    if (i != 0) os << ", ";
+    os << to_shuffle[*val];
+  }
+  os << ')';
+}
 
 inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*)
   switch (op->type.bits()) {
diff --git a/src/codegen/codegen_cuda.h b/src/codegen/codegen_cuda.h
index acd759f33889..61c6fa3a5170 100644
--- a/src/codegen/codegen_cuda.h
+++ b/src/codegen/codegen_cuda.h
@@ -57,6 +57,7 @@ class CodeGenCUDA final : public CodeGenC {
   void BindThreadIndex(const IterVar& iv) final;  // NOLINT(*)
   // overload visitor
   void VisitExpr_(const Ramp* op, std::ostream& os) final;  // NOLINT(*)
+  void VisitExpr_(const Shuffle* op, std::ostream& os) final;  // NOLINT(*)
   void VisitExpr_(const Broadcast* op, std::ostream& os) final;  // NOLINT(*)
   void VisitExpr_(const FloatImm *op, std::ostream& os) final;
   void VisitStmt_(const Evaluate *op) final;
diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc
index 537de1efc0f4..5bc415f2c510 100644
--- a/src/codegen/llvm/codegen_llvm.cc
+++ b/src/codegen/llvm/codegen_llvm.cc
@@ -30,6 +30,7 @@
 
 #include "codegen_llvm.h"
 #include "codegen_cpu.h"
+#include "../build_common.h"
 #include "../../pass/ir_util.h"
 #include "../../arithmetic/compute_expr.h"
 
@@ -446,6 +447,7 @@ llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
 llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) {
   int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
   if (extent == num_elems && begin == 0) return vec;
+  CHECK(begin >= 0 && extent <= num_elems) << "Slicing out of bounds!\n";
   std::vector<llvm::Constant*> indices;
   indices.reserve(extent);
   for (int i = 0; i < extent; ++i) {
@@ -481,6 +483,7 @@ llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) {
 llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector<llvm::Value*> vecs) {
   // concat vector, tree shape reduction
   int total_lanes = 0;
+
   for (llvm::Value* v : vecs) {
     total_lanes += static_cast<int>(
         v->getType()->getVectorNumElements());
@@ -652,12 +655,14 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
   CHECK_GE(op->args.size(), 2U);
   llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(
       op->args[0].as<UIntImm>()->value);
-  uint64_t num_signature = op->args[1].as<UIntImm>()->value;
+  const uint64_t *num_signature = as_const_uint(op->args[1]);
+  CHECK(num_signature) << "The second argument should be a uint representing the number of arguments, "
+                       << "but got " << op->args[1] << "!\n";
   std::vector<llvm::Value*> arg_value;
   std::vector<llvm::Type*> sig_type;
   for (size_t i = 2; i < op->args.size(); ++i) {
     arg_value.push_back(MakeValue(op->args[i]));
-    if (i - 2 < num_signature) {
+    if (i - 2 < *num_signature) {
       sig_type.push_back(arg_value.back()->getType());
     }
   }
@@ -1002,6 +1007,26 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Ramp* op) {
   return vec;
 }
 
+llvm::Value* CodeGenLLVM::VisitExpr_(const Shuffle* op) {
+  std::vector<llvm::Value*> vecs(op->vectors.size());
+  int total_lanes = 0;
+  for (int i = 0, e = op->vectors.size(); i < e; ++i) {
+    vecs[i] = VisitExpr(op->vectors[i]);
+    total_lanes += op->vectors[i].type().lanes();
+  }
+  llvm::Value* v0 = CreateVecConcat(vecs);
+  std::vector<uint32_t> idx(op->indices.size());
+  for (int i = 0, e = op->indices.size(); i < e; ++i) {
+    const int64_t *val = as_const_int(op->indices[i]);
+    CHECK(val && *val >= 0 && *val < total_lanes) << "Shuffle indices are supposed to be constant ints, "
+                                                  << "but got " << op->indices[i] << "\n";
+    idx[i] = *val;
+  }
+  llvm::Value* mask = llvm::ConstantDataVector::get(builder_->getContext(), idx);
+  auto res = builder_->CreateShuffleVector(v0, llvm::UndefValue::get(v0->getType()), mask);
+  return res;
+}
+
 llvm::Value* CodeGenLLVM::VisitExpr_(const Broadcast* op) {
   return CreateBroadcast(MakeValue(op->value), op->lanes);
 }
diff --git a/src/codegen/llvm/codegen_llvm.h b/src/codegen/llvm/codegen_llvm.h
index 6be860aec379..13dc6acee27c 100644
--- a/src/codegen/llvm/codegen_llvm.h
+++ b/src/codegen/llvm/codegen_llvm.h
@@ -131,6 +131,7 @@ class CodeGenLLVM :
   llvm::Value* VisitExpr_(const Load* op) override;
   llvm::Value* VisitExpr_(const Call* op) override;
   llvm::Value* VisitExpr_(const Ramp* op) override;
+  llvm::Value* VisitExpr_(const Shuffle* op) override;
   llvm::Value* VisitExpr_(const Broadcast* op) override;
   // stmt
   void VisitStmt_(const Store* op) override;
diff --git a/src/pass/ir_visitor.cc b/src/pass/ir_visitor.cc
index dd469f45c662..fde183e0c41a 100644
--- a/src/pass/ir_visitor.cc
+++ b/src/pass/ir_visitor.cc
@@ -177,6 +177,13 @@ void IRVisitor::Visit_(const Ramp *op) {
   this->Visit(op->stride);
 }
 
+void IRVisitor::Visit_(const Shuffle *op) {
+  for (const auto &elem : op->indices)
+    this->Visit(elem);
+  for (const auto &elem : op->vectors)
+    this->Visit(elem);
+}
+
 void IRVisitor::Visit_(const Broadcast *op) {
   this->Visit(op->value);
 }
@@ -269,6 +276,7 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
 .DISPATCH_TO_VISIT(Not)
 .DISPATCH_TO_VISIT(Select)
 .DISPATCH_TO_VISIT(Ramp)
+.DISPATCH_TO_VISIT(Shuffle)
 .DISPATCH_TO_VISIT(Broadcast)
 .DISPATCH_TO_VISIT(AssertStmt)
 .DISPATCH_TO_VISIT(ProducerConsumer)
diff --git a/tests/python/unittest/test_codegen_cuda.py b/tests/python/unittest/test_codegen_cuda.py
index 8fe6720830a5..e8439de6089b 100644
--- a/tests/python/unittest/test_codegen_cuda.py
+++ b/tests/python/unittest/test_codegen_cuda.py
@@ -154,9 +154,53 @@ def check_inf_nan(ctx, n, value, dtype):
     check_inf_nan(ctx, 1, float('nan'), 'float64')
 
 
+def test_cuda_shuffle():
+    if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"):
+        print("skip because cuda is not enabled..")
+        return
+
+    a = tvm.placeholder((64, ), 'int32')
+    b = tvm.placeholder((64, ), 'int32')
+    c = tvm.compute((64, ), lambda x: a[x] + b[x - (x % 4) + (3 - x % 4)])
+    sch = tvm.create_schedule(c.op)
+    x = c.op.axis[0]
+    xo, xi = sch[c].split(x, 4)
+    thrx = tvm.thread_axis("threadIdx.x")
+    sch[c].bind(xo, thrx)
+    sch[c].vectorize(xi)
+
+    def my_vectorize(stmt):
+        def vectorizer(op):
+            if op.for_type == tvm.stmt.For.Vectorized:
+                four = tvm.const(4, 'int32')
+                idx = tvm.make.Ramp(thrx.var * four, tvm.const(1, 'int32'), 4)
+                all_ones = tvm.const(1, 'int32x4')
+                store = op.body
+                value = store.value
+                new_a = tvm.make.Load('int32x4', value.a.buffer_var, idx, all_ones)
+                bs, ids = [], []
+                for i in range(4):
+                    bs.append(tvm.make.Load('int32', value.b.buffer_var, thrx.var * four + tvm.const(i, 'int32')))
+                    ids.append(tvm.const(3 - i, 'int32'))
+                new_b = tvm.make.Shuffle(bs, ids)
+                return tvm.make.Store(store.buffer_var, new_a + new_b, idx, all_ones)
+            return None
+        return tvm.ir_pass.IRTransform(stmt, None, vectorizer, ['For'])
+
+    with tvm.build_config(add_lower_pass=[(1, my_vectorize)]):
+        module = tvm.build(sch, [a, b, c], target='cuda')
+        a_ = np.array(list(range(64)), dtype='int32')
+        b_ = np.array((list(range(4))[::-1]) * 16, dtype='int32')
+        c_ = np.zeros((64, ), dtype='int32')
+        ref = a_ + np.array((list(range(4))) * 16, dtype='int32')
+        nda, ndb, ndc = [tvm.ndarray.array(i, tvm.gpu(0)) for i in [a_, b_, c_]]
+        module(nda, ndb, ndc)
+        tvm.testing.assert_allclose(ndc.asnumpy(), ref)
+
 if __name__ == "__main__":
     test_cuda_vectorize_add()
     test_cuda_multiply_add()
     test_cuda_vectorize_load()
     test_cuda_make_int8x4()
     test_cuda_inf_nan()
+    test_cuda_shuffle()
diff --git a/tests/python/unittest/test_codegen_llvm.py b/tests/python/unittest/test_codegen_llvm.py
index ed6fedcb0d17..34dad36a9076 100644
--- a/tests/python/unittest/test_codegen_llvm.py
+++ b/tests/python/unittest/test_codegen_llvm.py
@@ -548,6 +548,37 @@ def check_llvm_ir():
     check_llvm_object()
     check_llvm_ir()
 
+
+def test_llvm_shuffle():
+    a = tvm.placeholder((8, ), 'int32')
+    b = tvm.placeholder((8, ), 'int32')
+    c = tvm.compute((8, ), lambda x: a[x] + b[7-x])
+    sch = tvm.create_schedule(c.op)
+
+    def my_vectorize(stmt):
+
+        def vectorizer(op):
+            store = op.body
+            idx = tvm.make.Ramp(tvm.const(0, 'int32'), tvm.const(1, 'int32'), 8)
+            all_ones = tvm.const(1, 'int32x8')
+            value = store.value
+            b_idx = tvm.make.Shuffle([idx], [tvm.const(i, 'int32') for i in range(7, -1, -1)])
+            new_a = tvm.make.Load('int32x8', value.a.buffer_var, idx, all_ones)
+            new_b = tvm.make.Load('int32x8', value.b.buffer_var, b_idx, all_ones)
+            value = new_a + new_b
+            return tvm.make.Store(store.buffer_var, value, idx, all_ones)
+
+        return tvm.ir_pass.IRTransform(stmt, None, vectorizer, ['For'])
+
+    with tvm.build_config(add_lower_pass=[(1, my_vectorize)]):
+        tvm.lower(sch, [a, b, c], simple_mode=True)
+        module = tvm.build(sch, [a, b, c])
+        a_ = tvm.ndarray.array(np.arange(1, 9, dtype='int32'))
+        b_ = tvm.ndarray.array(np.arange(8, 0, -1, dtype='int32'))
+        c_ = tvm.ndarray.array(np.zeros((8, ), dtype='int32'))
+        module(a_, b_, c_)
+        tvm.testing.assert_allclose(c_.asnumpy(), (a_.asnumpy() * 2).astype('int32'))
+
 if __name__ == "__main__":
     test_llvm_import()
     test_alignment()
@@ -567,3 +598,4 @@ def check_llvm_ir():
     test_llvm_div()
     test_llvm_fp_math()
     test_dwarf_debug_information()
+    test_llvm_shuffle()
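
For readers skimming the diff, here is a minimal standalone sketch (not part of the patch; the variable `base` and the comments are illustrative only) of what the new Shuffle node expresses, built with the same tvm.make constructors the tests above use:

    import tvm

    # Shuffle(vectors, indices) concatenates `vectors` lane-wise and picks the
    # lanes named by the constant `indices`; here we reverse a 4-lane ramp.
    base = tvm.var('base', 'int32')
    vec = tvm.make.Ramp(base, tvm.const(1, 'int32'), 4)   # <base, base+1, base+2, base+3>
    rev = tvm.make.Shuffle([vec], [tvm.const(i, 'int32') for i in (3, 2, 1, 0)])
    print(rev)  # denotes <base+3, base+2, base+1, base>

On the LLVM backend this lowers to a single shufflevector over the concatenated input vectors; the CUDA backend only accepts scalar inputs and emits a make_int4(...)-style constructor, which is why the CUDA test shuffles four scalar loads rather than one vector load.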