Skip to content

Commit

Permalink
Merge branch 'tilelang' of github.com:TileLang/tvm into tilelang
Browse files Browse the repository at this point in the history
  • Loading branch information
LeiWang1999 committed Oct 27, 2024
2 parents b322866 + e2f6bf8 commit bccac9f
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 28 deletions.
23 changes: 21 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -370,10 +370,27 @@ list(APPEND COMPILER_SRCS "src/target/datatype/myfloat/myfloat.cc")
tvm_file_glob(GLOB TILE_LIBRARY_SRCS
src/tl/*.cc
src/tl/layout/*.cc
src/tl/target/*.cc
src/tl/transform/*.cc
src/tl/op/*.cc
src/tl/target/utils.cc
)

# Backend-specific TileLang sources: the CUDA codegen/runtime-module files
# reference CUDA headers, so compile them only when CUDA is enabled.
if(USE_CUDA)
tvm_file_glob(GLOB TILE_LIBRARY_CUDA_SRCS
src/tl/target/codegen_cuda.cc
src/tl/target/rt_mod_cuda.cc
)
list(APPEND TILE_LIBRARY_SRCS ${TILE_LIBRARY_CUDA_SRCS})
endif(USE_CUDA)

# Same for the ROCm/HIP codegen and runtime-module sources.
if(USE_ROCM)
tvm_file_glob(GLOB TILE_LIBRARY_HIP_SRCS
src/tl/target/codegen_rocm.cc
src/tl/target/rt_mod_rocm.cc
)
list(APPEND TILE_LIBRARY_SRCS ${TILE_LIBRARY_HIP_SRCS})
endif(USE_ROCM)

list(APPEND COMPILER_SRCS ${TILE_LIBRARY_SRCS})

tvm_file_glob(GLOB RUNTIME_SRCS
Expand All @@ -385,7 +402,9 @@ tvm_file_glob(GLOB RUNTIME_SRCS
src/runtime/relax_vm/*.cc
)
if(USE_CUDA)
list(APPEND TILELANG_RUNTIME_SRCS src/tl/runtime/*.cc)
tvm_file_glob(GLOB TILELANG_RUNTIME_SRCS
src/tl/runtime/*.cc
)
list(APPEND RUNTIME_SRCS ${TILELANG_RUNTIME_SRCS})
endif(USE_CUDA)

Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

TVM.TL
==============================================
**Not an official release, but you can use TL via [BitBLAS](https://github.com/microsoft/BitBLAS)**

TVM.TL is an extension of TVMScript for writing simple, high-performance GPU kernels with tensor cores. TVM.TL is currently supported on CUDA devices with Ampere (sm_80+), Turing (sm_75), and Volta (sm_70) architectures.

Let's get started with a simple GEMM example.
Expand Down
137 changes: 116 additions & 21 deletions src/tl/transform/loop_vectorize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ namespace tl {

using namespace tir;

// Outcome of planning vectorization for an innermost loop.
// Aggregate-initialized as {vector_size, dynamic, condition} by callers,
// so the field order must not change.
struct VectorizePlanResult {
int vector_size;  // number of lanes chosen for the vectorized loop
bool dynamic;     // true when the buffer shape is dynamic, making vectorization conditional
PrimExpr condition;  // guard predicate for the vectorized path; meaningful only when `dynamic`
};

class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
public:
VectorizePlanner() = default;
Expand All @@ -51,6 +57,14 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
return vector_size_;
}

bool GetDynamic() {
return dynamic_;
}

PrimExpr GetCondition() {
return condition_;
}

private:
void VisitStmt_(const ForNode* node) final {
inner_for_ = node;
Expand Down Expand Up @@ -107,12 +121,20 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
int max_vector_size = arith::ZeroAwareGCD(128 / access_type.bits(), extent_ptr->value);

auto mod_set = analyzer_.modular_set(buffer->shape.back());
max_vector_size = arith::ZeroAwareGCD(max_vector_size, mod_set->coeff);
max_vector_size = arith::ZeroAwareGCD(max_vector_size, mod_set->base);
vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_);
while (!IndiceCanVectorize(buffer.OffsetOf(indices).back(), inner_for_->loop_var,
inner_for_->extent, vector_size_, &analyzer_)) {
vector_size_ /= 2;
// For dynamic shapes like [m, k], the modular set degenerates to coeff=1, base=0,
// so folding it into the GCD would force vector_size down to 1 and block the
// conditional (tail-guarded) vectorization handled in the else branch below.
if (buffer->shape.back().as<IntImmNode>()) {
max_vector_size = arith::ZeroAwareGCD(max_vector_size, mod_set->coeff);
max_vector_size = arith::ZeroAwareGCD(max_vector_size, mod_set->base);
vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_);
while (!IndiceCanVectorize(buffer.OffsetOf(indices).back(), inner_for_->loop_var,
inner_for_->extent, vector_size_, &analyzer_)) {
vector_size_ /= 2;
}
} else if (vector_size_ <= vector_load_bits_max_ / buffer->dtype.bits()) {
// dynamic shape load: get the vectorization condition
dynamic_ = true;
PrimExpr offset = buffer.OffsetOf(indices).back();
condition_ = (FloorMod(offset, vector_size_) == 0);
}
}

Expand All @@ -122,11 +144,39 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
Map<Var, Range> iter_map_;
bool has_nonlocal_memory_access_ = false;
int vector_size_ = 128;
// conditionally vectorize
bool dynamic_ = false;
PrimExpr condition_;
};

// Rewrites tir.if_then_else calls inside a loop body that is about to be
// vectorized under a dynamic-shape guard: the innermost (vector) loop
// variable is substituted with 0 in each call's condition, removing the
// per-lane upper-bound check (the surrounding IfThenElse emitted by
// VectorizeRewriter supplies the alignment guard instead).
// NOTE(review): dropping the bound check assumes the outer guard already
// guarantees in-bounds access for all lanes — confirm against the rewriter.
class VectorizeDynamicCallRemover : public StmtExprMutator {
 public:
  VectorizeDynamicCallRemover(Var inner_var, int vector_size):
    inner_var_(inner_var), vector_size_(vector_size) {}

 private:
  PrimExpr VisitExpr_(const CallNode* op) final {
    if (op->op.same_as(builtin::if_then_else())) {
      PrimExpr cond = this->VisitExpr(op->args[0]);
      // Currently remove the upper-bound check by pinning the vector lane
      // index to 0 inside the condition.
      Map<Var, PrimExpr> vmap;
      vmap.Set(inner_var_, 0);
      cond = Substitute(cond, vmap);
      // Recurse into the value operands as well, so if_then_else calls
      // nested inside them are rewritten too (the original left them
      // untouched).
      PrimExpr then_value = this->VisitExpr(op->args[1]);
      PrimExpr else_value = this->VisitExpr(op->args[2]);
      Array<PrimExpr> new_args{cond, then_value, else_value};
      return Call(op->dtype, op->op, new_args, op->span);
    }
    // Delegate to the base mutator so children of every other call are
    // still visited (the original returned the node unvisited, skipping
    // nested if_then_else calls).
    return StmtExprMutator::VisitExpr_(op);
  }

  Var inner_var_;    // innermost (vectorized) loop variable being eliminated from guards
  int vector_size_;  // lanes of the vectorized loop; kept for interface parity
};

class VectorizeRewriter : public StmtExprMutator {
public:
VectorizeRewriter(int vector_size) : vector_size_(vector_size) {}
VectorizeRewriter(VectorizePlanResult plan):
vector_size_(plan.vector_size), condition_(plan.condition), dynamic_(plan.dynamic) {}

private:
Stmt VisitStmt_(const ForNode* node) final {
Expand All @@ -140,19 +190,51 @@ class VectorizeRewriter : public StmtExprMutator {
int extent = *extent_ptr;
ICHECK(extent % vector_size_ == 0);
ICHECK(is_zero(fnode->min));
if (extent == vector_size_) {
fnode.CopyOnWrite()->kind = ForKind::kVectorized;
return fnode;
if (!dynamic_) { // check dynamic shape
if (extent == vector_size_) {
fnode.CopyOnWrite()->kind = ForKind::kVectorized;
return fnode;
} else {
Var inner_var = Var("vec");
Var outer_var = Var(old_var->name_hint);
Map<Var, PrimExpr> vmap;
vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
Stmt body = Substitute(fnode->body, vmap);
body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
body = For(outer_var, 0, extent / vector_size_, fnode->kind, body, fnode->thread_binding,
fnode->annotations, fnode->span);
return body;
}
} else {
Var inner_var = Var("vec");
Var outer_var = Var(old_var->name_hint);
Map<Var, PrimExpr> vmap;
vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
Stmt body = Substitute(fnode->body, vmap);
body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
body = For(outer_var, 0, extent / vector_size_, fnode->kind, body, fnode->thread_binding,
fnode->annotations, fnode->span);
return body;
if (extent == vector_size_) {
// add condition ifthenelse here
For vectorize_for = fnode;
vectorize_for.CopyOnWrite()->kind = ForKind::kVectorized;
For serial_for = fnode;
serial_for.CopyOnWrite()->kind = ForKind::kSerial;
Stmt body = IfThenElse(condition_, vectorize_for, serial_for);
return body;
} else {
Var inner_var = Var("vec");
Var outer_var = Var(old_var->name_hint);
Map<Var, PrimExpr> vmap;
vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
Stmt body = Substitute(fnode->body, vmap);
// add condition ifthenelse here
Map<Var, PrimExpr> vmap_condition;
vmap_condition.Set(fnode->loop_var, outer_var * vector_size_);
PrimExpr condition = Substitute(condition_, vmap_condition);

VectorizeDynamicCallRemover remover(inner_var, vector_size_);
body = remover(body);

For vectorize_for = For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
For serial_for = For(inner_var, 0, vector_size_, ForKind::kSerial, body);
body = IfThenElse(condition, vectorize_for, serial_for);
body = For(outer_var, 0, extent / vector_size_, fnode->kind, body, fnode->thread_binding,
fnode->annotations, fnode->span);
return body;
}
}
} else {
return ret;
Expand All @@ -161,10 +243,20 @@ class VectorizeRewriter : public StmtExprMutator {

const ForNode* inner_for_;
const int vector_size_;
const PrimExpr condition_;
const bool dynamic_;
};

int GetVectorizeSize(const For& loop) { return VectorizePlanner().Plan(loop); }

// Run the vectorization planner on `loop` and bundle its three outputs
// (lane count, dynamic-shape flag, guard condition) into one result.
VectorizePlanResult GetVectorizePlanResult(const For& loop) {
  VectorizePlanner planner;
  const int lanes = planner.Plan(loop);
  return VectorizePlanResult{lanes, planner.GetDynamic(), planner.GetCondition()};
}

// Use the same code as tir.transform.vectorize_loop
class VectorizeChecker : public ExprMutator {
public:
Expand Down Expand Up @@ -444,11 +536,14 @@ bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size, int targ
}

// Vectorize the innermost loop of `loop`.
// If `vectorize_hint <= 0`, the planner decides the lane count (and whether a
// dynamic-shape guard is needed); otherwise the caller's hint is used as-is.
// Returns `loop` unchanged when only 1 lane is possible.
For VectorizeLoop(const For& loop, int vectorize_hint) {
  // Seed the plan with the caller's hint so an explicit hint is actually
  // honored (previously the rewriter always saw the dummy {128, false, 0}
  // plan when a positive hint was supplied, silently ignoring the hint).
  VectorizePlanResult res{vectorize_hint, false, 0};
  if (vectorize_hint <= 0) {
    res = GetVectorizePlanResult(loop);
    vectorize_hint = res.vector_size;
  }
  if (vectorize_hint == 1) return loop;
  auto rewriter = VectorizeRewriter(res);
  return Downcast<For>(rewriter(loop));
}

Expand Down
17 changes: 12 additions & 5 deletions src/tl/transform/pipeline_planning.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,15 @@ class PipelinePlanner : public StmtExprMutator {
pinfo.reads = std::move(access[0]);
pinfo.writes = std::move(access[1]);
pinfo.original_order = idx;
for (auto region : pinfo.reads)
if (region->buffer.scope() == "global") pinfo.copy_stage = true;
for (auto region : pinfo.writes)
if (region->buffer.scope() == "global") pinfo.copy_stage = true;

// A copy stage should have exactly one read region and one write region.
if (pinfo.reads.size() == 1 && pinfo.writes.size() == 1) {
for (auto region : pinfo.reads)
if (region->buffer.scope() == "global") pinfo.copy_stage = true;
for (auto region : pinfo.writes)
if (region->buffer.scope() == "global") pinfo.copy_stage = true;
}

return std::move(pinfo);
}

Expand Down Expand Up @@ -160,7 +165,9 @@ class PipelinePlanner : public StmtExprMutator {
}
}
}
ICHECK(size_t(order_idx) == pipeline_stage_infos.size());
ICHECK(size_t(order_idx) == pipeline_stage_infos.size()) <<
"The number of stages should be equal to the number of pipeline stages. " <<
"Got " << order_idx << " stages and " << pipeline_stage_infos.size() << " pipeline stages.";

// if all the copy is at the end of the order, we can move these copy to the begining of the
// order and shrink the stage offset by 1.
Expand Down

0 comments on commit bccac9f

Please sign in to comment.