Skip to content

Commit

Permalink
testing 16x8x8 ldmatrix tensoriation
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 17, 2022
1 parent a3a4155 commit 4cf6b20
Showing 1 changed file with 89 additions and 4 deletions.
93 changes: 89 additions & 4 deletions src/tir/transforms/lower_warp_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ namespace tir {

// Visitor to find m in pattern
// store warp_mem[m * warp_index + (width * m) * y + x]
class WarpStoreCoeffFinder : private StmtVisitor {
class WarpStoreCoeffFinder : private StmtExprVisitor {
public:
WarpStoreCoeffFinder(const VarNode* buffer, Var warp_index, arith::Analyzer* analyzer)
: buffer_(buffer), warp_index_(warp_index), analyzer_(analyzer) {}
Expand All @@ -113,6 +113,14 @@ class WarpStoreCoeffFinder : private StmtVisitor {

private:
/// Visitor implementation
void VisitExpr_(const CallNode* op) final {
if (op->op.same_as(builtin::ptx_ldmatrix()) && op->args[3].as<VarNode>() == buffer_) {
int num_matrix = op->args[1].as<IntImmNode>()->value;
warp_coeff_ = num_matrix * 2;
}
StmtExprVisitor::VisitExpr_(op);
}

void VisitStmt_(const StoreNode* op) final {
LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
}
Expand Down Expand Up @@ -142,6 +150,7 @@ class WarpStoreCoeffFinder : private StmtVisitor {
}

void UpdatePattern(const PrimExpr& index) {
LOG(INFO) << index;
Array<PrimExpr> m = arith::DetectLinearEquation(index, {warp_index_});
ICHECK_EQ(m.size(), 2U)
<< "LowerWarpMemory failed. Could not simplify the store index `" << index
Expand All @@ -157,7 +166,7 @@ class WarpStoreCoeffFinder : private StmtVisitor {

if (warp_coeff_ != 0) {
ICHECK_EQ(warp_coeff_, mcoeff_as_int->value)
<< "LowerWarpMemory failed due to two different store coefficient to warp index";
<< "LowerWarpMemory failed due to two different store coefficient to warp index" << warp_coeff_ << ", "<< mcoeff_as_int->value;
} else {
warp_coeff_ = mcoeff_as_int->value;
}
Expand All @@ -174,7 +183,7 @@ class WarpStoreCoeffFinder : private StmtVisitor {
};

// Visitor to find the warp index
class WarpIndexFinder : private StmtVisitor {
class WarpIndexFinder : private StmtExprVisitor {
public:
explicit WarpIndexFinder(int warp_size) : warp_size_(warp_size) {}
// find the warp co-efficient and the shuffle width in the statement
Expand Down Expand Up @@ -206,6 +215,7 @@ class WarpIndexFinder : private StmtVisitor {
} else {
width_ = value_as_int->value;
warp_index_ = iv;
// LOG(INFO) << "Using warp index from " << GetRef<AttrStmt>(op);
}
}
}
Expand All @@ -226,15 +236,26 @@ class WarpAccessRewriter : protected StmtExprMutator {
// Rewrite the allocate statement which transforms
// warp memory to local memory.
Stmt Rewrite(const AllocateNode* op) {
// LOG(INFO) << "Allocate " << GetRef<Allocate>(op);
// LOG(INFO) << op->buffer_var;
buffer_ = op->buffer_var.get();
int alloc_size = op->ConstantAllocationSize();
ICHECK_GT(alloc_size, 0) << "warp memory only support constant alloc size";
alloc_size *= op->dtype.lanes();
std::tie(warp_index_, width_) = WarpIndexFinder(warp_size_).Find(op->body);
LOG(INFO) << warp_index_ << ", " << width_;
LOG(INFO) << op->buffer_var;
warp_coeff_ = WarpStoreCoeffFinder(buffer_, warp_index_, analyzer_).Find(op->body);
// if (warp_size_ == width_) {
// warp_coeff_ = 1;
// } else {
// LOG(INFO) << alloc_size << ", " << warp_size_ << ", " << width_;
// warp_coeff_ = WarpStoreCoeffFinder(buffer_, warp_index_, analyzer_).Find(op->body);
// }

// Align the local memory size. The number of elements may not
// be a multiple of width_ * warp_coeff_; round it up.
LOG(INFO) << warp_coeff_;
int factor = width_ * warp_coeff_;
ICHECK_NE(factor, 0) << "Divide by zero";
warp_group_ = (alloc_size + (factor - 1)) / factor;
Expand All @@ -245,7 +266,65 @@ class WarpAccessRewriter : protected StmtExprMutator {
}

protected:

PrimExpr VisitExpr_(const CallNode* op) override {
if (op->op.same_as(builtin::ptx_mma())) {
Array<PrimExpr> new_args = op->args;
PrimExpr local_index, group;
bool changed = false;
// Arguments here look like [mma_name, Multiplicand A, offset/bias of A, Multiplicand
// B, offset/bias of B, Accumulator C, offset/bias of C]. Therefore 2,4,6 are the indices
// for offset/bias of Multiplicand A, B and Accumulator C for MMA instruction where 1,3,5
// are the indices for Multiplicand A, B and Accumulator C.
// Because this pass only process one buffer at a time, we need to make sure this function
// only changes the offset of the buffer being processed
int A_warp_arg_ind = 6;
for (int i = A_warp_arg_ind; i < A_warp_arg_ind + 6; i += 2) {
if (op->args[i].get() == buffer_) {
std::tie(local_index, group) = SplitIndexByGroup(op->args[i + 1]);
LOG(INFO) << "local index for " << op->args[i] << " : " << local_index;
new_args.Set(i + 1, local_index);
changed = true;
}
}
if (!changed) return GetRef<PrimExpr>(op);
return Call(op->dtype, op->op, new_args);
}
if (op->op.same_as(builtin::ptx_ldmatrix())) {
Array<PrimExpr> new_args = op->args;
PrimExpr local_index, group;
if (op->args[3].get() == buffer_) {
new_args.Set(4, 0);
return Call(op->dtype, op->op, new_args);
}
return GetRef<PrimExpr>(op);
}
if (op->op.same_as(builtin::call_extern()) &&
Downcast<StringImm>(op->args[0])->value == "mma_m16n8k8_row_row_fp16fp16fp32") {
Array<PrimExpr> new_args = op->args;
PrimExpr local_index, group;
bool changed = false;
// Arguments here look like [mma_name, Multiplicand A, offset/bias of A, Multiplicand
// B, offset/bias of B, Accumulator C, offset/bias of C]. Therefore 2,4,6 are the indices
// for offset/bias of Multiplicand A, B and Accumulator C for MMA instruction where 1,3,5
// are the indices for Multiplicand A, B and Accumulator C.
// Because this pass only process one buffer at a time, we need to make sure this function
// only changes the offset of the buffer being processed
for (int i = 1; i < 6; i += 2) {
if (op->args[i].get() == buffer_) {
std::tie(local_index, group) = SplitIndexByGroup(op->args[i + 1]);
new_args.Set(i + 1, local_index);
changed = true;
}
}
if (!changed) return GetRef<PrimExpr>(op);
return Call(op->dtype, op->op, new_args);
}
return StmtExprMutator::VisitExpr_(op);
}

PrimExpr VisitExpr_(const VarNode* op) override {
// LOG(INFO) << GetRef<Var>(op);
ICHECK(op != buffer_) << "Cannot access address of warp memory directly";
return StmtExprMutator::VisitExpr_(op);
}
Expand Down Expand Up @@ -292,6 +371,10 @@ class WarpAccessRewriter : protected StmtExprMutator {
PrimExpr local_index, group;
std::tie(local_index, group) = SplitIndexByGroup(op->indices[0]);
// invariance: local index must do not contain warp id
LOG(INFO) << "local_index = " << local_index;
LOG(INFO) << "group = " << group;
LOG(INFO) << "warp_index_ = " << warp_index_;

ICHECK(!UsesVar(local_index, [this](const VarNode* var) { return var == warp_index_.get(); }))
<< "LowerWarpMemory failed to rewrite load to shuffle for index " << op->indices[0]
<< " local_index=" << local_index;
Expand Down Expand Up @@ -428,11 +511,13 @@ Pass LowerWarpMemory() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
LOG(INFO) << f;
ICHECK(target.defined()) << "LowerWarpMemory: Require the target attribute";
int warp_size = target.value()->GetAttr<Integer>("thread_warp_size", 1).value();
int warp_size = 32;
WarpMemoryRewriter warp_memory_rewriter(warp_size);
auto stmt = warp_memory_rewriter.Rewrite(std::move(n->body));
n->body = UpdatePointerStorageScope(warp_memory_rewriter.new_storage_scopes_)(stmt);
LOG(INFO) << "After lower warp: " << f;
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {});
Expand Down

0 comments on commit 4cf6b20

Please sign in to comment.