Skip to content

Commit

Permalink
Add shuffle support to TVM (#3633)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jian Weng authored and tqchen committed Aug 1, 2019
1 parent 9ae01e0 commit a279dd0
Show file tree
Hide file tree
Showing 12 changed files with 141 additions and 5 deletions.
1 change: 1 addition & 0 deletions include/tvm/ir_functor_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
IR_EXPR_FUNCTOR_DISPATCH(Not);
IR_EXPR_FUNCTOR_DISPATCH(Select);
IR_EXPR_FUNCTOR_DISPATCH(Ramp);
IR_EXPR_FUNCTOR_DISPATCH(Shuffle);
IR_EXPR_FUNCTOR_DISPATCH(Broadcast);
IR_EXPR_FUNCTOR_DISPATCH(IntImm);
IR_EXPR_FUNCTOR_DISPATCH(UIntImm);
Expand Down
1 change: 1 addition & 0 deletions include/tvm/ir_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ class TVM_DLL IRVisitor {
virtual void Visit_(const Not* op);
virtual void Visit_(const Select* op);
virtual void Visit_(const Ramp* op);
virtual void Visit_(const Shuffle* op);
virtual void Visit_(const Broadcast* op);
virtual void Visit_(const AssertStmt* op);
virtual void Visit_(const ProducerConsumer* op);
Expand Down
1 change: 1 addition & 0 deletions src/codegen/build_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#define TVM_CODEGEN_BUILD_COMMON_H_

#include <tvm/codegen.h>
#include <tvm/ir.h>
#include <unordered_map>
#include <string>
#include "../runtime/meta_data.h"
Expand Down
4 changes: 4 additions & 0 deletions src/codegen/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,10 @@ void CodeGenC::VisitExpr_(const Ramp* op, std::ostream& os) { // NOLINT(*)
os << "))";
}

void CodeGenC::VisitExpr_(const Shuffle* op, std::ostream& os) {
LOG(FATAL) << "Shuffle: not supported ";
}

void CodeGenC::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Broadcast: not supported ";
}
Expand Down
1 change: 1 addition & 0 deletions src/codegen/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ class CodeGenC :
void VisitExpr_(const Not* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Select* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Ramp* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Shuffle* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Broadcast* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const IntImm* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const UIntImm* op, std::ostream& os) override; // NOLINT(*)
Expand Down
23 changes: 20 additions & 3 deletions src/codegen/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,15 +205,15 @@ void CodeGenCUDA::PrintVecBinaryOp(

void CodeGenCUDA::PrintVecElemLoad(
const std::string& vec, Type t, int i, std::ostream& os) { // NOLINT(*)
const char access[] = {'x', 'y', 'z', 'w'};
static const char access[] = {'x', 'y', 'z', 'w'};
CHECK(i >= 0 && i < 4);
os << vec << "." << access[i];
}

void CodeGenCUDA::PrintVecElemStore(
const std::string& vec, Type t, int i, const std::string& value) {
this->PrintIndent();
const char access[] = {'x', 'y', 'z', 'w'};
static const char access[] = {'x', 'y', 'z', 'w'};
CHECK(i >= 0 && i < 4);
stream << vec << "." << access[i] << " = " << value << ";\n";
}
Expand Down Expand Up @@ -308,14 +308,31 @@ void CodeGenCUDA::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLIN
std::string v = PrintExpr(op->value);
os << "make_";
PrintType(op->type, os);
os << "(";
os << '(';
for (int i = 0; i < op->lanes; ++i) {
if (i != 0) os << ", ";
os << v;
}
os << ')';
}

void CodeGenCUDA::VisitExpr_(const Shuffle* op, std::ostream &os) {
std::vector<std::string> to_shuffle(op->vectors.size());
for (int i = 0, e = op->vectors.size(); i < e; ++i) {
CHECK(op->vectors[i].type().lanes() == 1) << "Only scalars can be shuffled in CUDA!";
to_shuffle[i] = PrintExpr(op->vectors[i]);
}
os << "make_";
PrintType(op->type, os);
os << '(';
for (int i = 0, e = op->indices.size(); i < e; ++i) {
const int64_t *val = as_const_int(op->indices[i]);
CHECK(val && *val >= 0 && (int) *val < (int) to_shuffle.size());
if (i != 0) os << ", ";
os << to_shuffle[*val];
}
os << ')';
}

inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*)
switch (op->type.bits()) {
Expand Down
1 change: 1 addition & 0 deletions src/codegen/codegen_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class CodeGenCUDA final : public CodeGenC {
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
// overload visitor
void VisitExpr_(const Ramp* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const Shuffle* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImm *op, std::ostream& os) final;
void VisitStmt_(const Evaluate *op) final;
Expand Down
29 changes: 27 additions & 2 deletions src/codegen/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

#include "codegen_llvm.h"
#include "codegen_cpu.h"
#include "../build_common.h"
#include "../../pass/ir_util.h"
#include "../../arithmetic/compute_expr.h"

Expand Down Expand Up @@ -446,6 +447,7 @@ llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) {
int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
if (extent == num_elems && begin == 0) return vec;
CHECK(begin >= 0 && extent <= num_elems) << "Slicing out of bound!\n";
std::vector<llvm::Constant*> indices;
indices.reserve(extent);
for (int i = 0; i < extent; ++i) {
Expand Down Expand Up @@ -481,6 +483,7 @@ llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) {
llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector<llvm::Value*> vecs) {
// concat vector, tree shape reduction
int total_lanes = 0;

for (llvm::Value* v : vecs) {
total_lanes += static_cast<int>(
v->getType()->getVectorNumElements());
Expand Down Expand Up @@ -652,12 +655,14 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
CHECK_GE(op->args.size(), 2U);
llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(
op->args[0].as<UIntImm>()->value);
uint64_t num_signature = op->args[1].as<UIntImm>()->value;
const uint64_t *num_signature = as_const_uint(op->args[1]);
CHECK(num_signature) << "The second argument should be a uint represents number of arguments, "
<< "but " << op->args[1] << " got!\n";
std::vector<llvm::Value*> arg_value;
std::vector<llvm::Type*> sig_type;
for (size_t i = 2; i < op->args.size(); ++i) {
arg_value.push_back(MakeValue(op->args[i]));
if (i - 2 < num_signature) {
if (i - 2 < *num_signature) {
sig_type.push_back(arg_value.back()->getType());
}
}
Expand Down Expand Up @@ -1002,6 +1007,26 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Ramp* op) {
return vec;
}

llvm::Value* CodeGenLLVM::VisitExpr_(const Shuffle* op) {
std::vector<llvm::Value *> vecs(op->vectors.size());
int total_lanes = 0;
for (int i = 0, e = op->vectors.size(); i < e; ++i) {
vecs[i] = VisitExpr(op->vectors[i]);
total_lanes += op->vectors[i].type().lanes();
}
llvm::Value* v0 = CreateVecConcat(vecs);
std::vector<uint32_t> idx(op->indices.size());
for (int i = 0, e = op->indices.size(); i < e; ++i) {
const int64_t *val = as_const_int(op->indices[i]);
CHECK(val && *val >= 0 && *val < total_lanes) << "Shuffled indeces are suppose to be int, "
<< "but get " << op->indices[i] << "\n";
idx[i] = *val;
}
llvm::Value* mask = llvm::ConstantDataVector::get(builder_->getContext(), idx);
auto res = builder_->CreateShuffleVector(v0, llvm::UndefValue::get(v0->getType()), mask);
return res;
}

llvm::Value* CodeGenLLVM::VisitExpr_(const Broadcast* op) {
return CreateBroadcast(MakeValue(op->value), op->lanes);
}
Expand Down
1 change: 1 addition & 0 deletions src/codegen/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ class CodeGenLLVM :
llvm::Value* VisitExpr_(const Load* op) override;
llvm::Value* VisitExpr_(const Call* op) override;
llvm::Value* VisitExpr_(const Ramp* op) override;
llvm::Value* VisitExpr_(const Shuffle* op) override;
llvm::Value* VisitExpr_(const Broadcast* op) override;
// stmt
void VisitStmt_(const Store* op) override;
Expand Down
8 changes: 8 additions & 0 deletions src/pass/ir_visitor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,13 @@ void IRVisitor::Visit_(const Ramp *op) {
this->Visit(op->stride);
}

void IRVisitor::Visit_(const Shuffle *op) {
for (const auto &elem : op->indices)
this->Visit(elem);
for (const auto &elem : op->vectors)
this->Visit(elem);
}

void IRVisitor::Visit_(const Broadcast *op) {
this->Visit(op->value);
}
Expand Down Expand Up @@ -269,6 +276,7 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.DISPATCH_TO_VISIT(Not)
.DISPATCH_TO_VISIT(Select)
.DISPATCH_TO_VISIT(Ramp)
.DISPATCH_TO_VISIT(Shuffle)
.DISPATCH_TO_VISIT(Broadcast)
.DISPATCH_TO_VISIT(AssertStmt)
.DISPATCH_TO_VISIT(ProducerConsumer)
Expand Down
44 changes: 44 additions & 0 deletions tests/python/unittest/test_codegen_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,53 @@ def check_inf_nan(ctx, n, value, dtype):
check_inf_nan(ctx, 1, float('nan'), 'float64')


def test_cuda_shuffle():
if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"):
print("skip because cuda is not enabled..")
return

a = tvm.placeholder((64, ), 'int32')
b = tvm.placeholder((64, ), 'int32')
c = tvm.compute((64, ), lambda x: a[x] + b[x - (x % 4) + (3 - x % 4)])
sch = tvm.create_schedule(c.op)
x = c.op.axis[0]
xo, xi = sch[c].split(x, 4)
thrx = tvm.thread_axis("threadIdx.x")
sch[c].bind(xo, thrx)
sch[c].vectorize(xi)

def my_vectorize(stmt):
def vectorizer(op):
if op.for_type == tvm.stmt.For.Vectorized:
four = tvm.const(4, 'int32')
idx = tvm.make.Ramp(thrx.var * four, tvm.const(1, 'int32'), 4)
all_ones = tvm.const(1, 'int32x4')
store = op.body
value = store.value
new_a = tvm.make.Load('int32x4', value.a.buffer_var, idx, all_ones)
bs, ids = [], []
for i in range(4):
bs.append(tvm.make.Load('int32', value.b.buffer_var, thrx.var * four + tvm.const(i, 'int32')))
ids.append(tvm.const(3 - i, 'int32'))
new_b = tvm.make.Shuffle(bs, ids)
return tvm.make.Store(store.buffer_var, new_a + new_b, idx, all_ones)
return None
return tvm.ir_pass.IRTransform(stmt, None, vectorizer, ['For'])

with tvm.build_config(add_lower_pass=[(1, my_vectorize)]):
module = tvm.build(sch, [a, b, c], target='cuda')
a_ = np.array(list(range(64)), dtype='int32')
b_ = np.array((list(range(4))[::-1]) * 16, dtype='int32')
c_ = np.zeros((64, ), dtype='int32')
ref = a_ + np.array((list(range(4))) * 16, dtype='int32')
nda, ndb, ndc = [tvm.ndarray.array(i, tvm.gpu(0)) for i in [a_, b_, c_]]
module(nda, ndb, ndc)
tvm.testing.assert_allclose(ndc.asnumpy(), ref)

if __name__ == "__main__":
test_cuda_vectorize_add()
test_cuda_multiply_add()
test_cuda_vectorize_load()
test_cuda_make_int8x4()
test_cuda_inf_nan()
test_cuda_shuffle()
32 changes: 32 additions & 0 deletions tests/python/unittest/test_codegen_llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,37 @@ def check_llvm_ir():
check_llvm_object()
check_llvm_ir()


def test_llvm_shuffle():
a = tvm.placeholder((8, ), 'int32')
b = tvm.placeholder((8, ), 'int32')
c = tvm.compute((8, ), lambda x: a[x] + b[7-x])
sch = tvm.create_schedule(c.op)

def my_vectorize(stmt):

def vectorizer(op):
store = op.body
idx = tvm.make.Ramp(tvm.const(0, 'int32'), tvm.const(1, 'int32'), 8)
all_ones = tvm.const(1, 'int32x8')
value = store.value
b_idx = tvm.make.Shuffle([idx], [tvm.const(i, 'int32') for i in range(7, -1, -1)])
new_a = tvm.make.Load('int32x8', value.a.buffer_var, idx, all_ones)
new_b = tvm.make.Load('int32x8', value.b.buffer_var, b_idx, all_ones)
value = new_a + new_b
return tvm.make.Store(store.buffer_var, new_a + new_b, idx, all_ones)

return tvm.ir_pass.IRTransform(stmt, None, vectorizer, ['For'])

with tvm.build_config(add_lower_pass=[(1, my_vectorize)]):
ir = tvm.lower(sch, [a, b, c], simple_mode=True)
module = tvm.build(sch, [a, b, c])
a_ = tvm.ndarray.array(np.arange(1, 9, dtype='int32'))
b_ = tvm.ndarray.array(np.arange(8, 0, -1, dtype='int32'))
c_ = tvm.ndarray.array(np.zeros((8, ), dtype='int32'))
module(a_, b_, c_)
tvm.testing.assert_allclose(c_.asnumpy(), (a_.asnumpy() * 2).astype('int32'))

if __name__ == "__main__":
test_llvm_import()
test_alignment()
Expand All @@ -567,3 +598,4 @@ def check_llvm_ir():
test_llvm_div()
test_llvm_fp_math()
test_dwarf_debug_information()
test_llvm_shuffle()

0 comments on commit a279dd0

Please sign in to comment.