From 0c859c4f385ba0b6f9477b569b80cee80b5b7282 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 26 Apr 2022 19:18:23 +0900 Subject: [PATCH] clean up --- src/tir/transforms/lower_warp_memory.cc | 57 +++---------------------- 1 file changed, 5 insertions(+), 52 deletions(-) diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 0c3f09b4277c..e9a8513391c5 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -150,7 +150,6 @@ class WarpStoreCoeffFinder : private StmtExprVisitor { } void UpdatePattern(const PrimExpr& index) { - LOG(INFO) << index; Array m = arith::DetectLinearEquation(index, {warp_index_}); ICHECK_EQ(m.size(), 2U) << "LowerWarpMemory failed. Could not simplify the store index `" << index @@ -166,7 +165,7 @@ class WarpStoreCoeffFinder : private StmtExprVisitor { if (warp_coeff_ != 0) { ICHECK_EQ(warp_coeff_, mcoeff_as_int->value) - << "LowerWarpMemory failed due to two different store coefficient to warp index" << warp_coeff_ << ", "<< mcoeff_as_int->value; + << "LowerWarpMemory failed due to two different store coefficient to warp index"; } else { warp_coeff_ = mcoeff_as_int->value; } @@ -183,7 +182,7 @@ class WarpStoreCoeffFinder : private StmtExprVisitor { }; // Visitor to find the warp index -class WarpIndexFinder : private StmtExprVisitor { +class WarpIndexFinder : private StmtVisitor { public: explicit WarpIndexFinder(int warp_size) : warp_size_(warp_size) {} // find the warp co-efficient and the shuffle width in the statement @@ -215,7 +214,6 @@ class WarpIndexFinder : private StmtExprVisitor { } else { width_ = value_as_int->value; warp_index_ = iv; - // LOG(INFO) << "Using warp index from " << GetRef(op); } } } @@ -236,26 +234,15 @@ class WarpAccessRewriter : protected StmtExprMutator { // Rewrite the allocate statement which transforms // warp memory to local memory. Stmt Rewrite(const AllocateNode* op) { - // LOG(INFO) << "Allocate " << GetRef(op); - // LOG(INFO) << op->buffer_var; buffer_ = op->buffer_var.get(); int alloc_size = op->ConstantAllocationSize(); ICHECK_GT(alloc_size, 0) << "warp memory only support constant alloc size"; alloc_size *= op->dtype.lanes(); std::tie(warp_index_, width_) = WarpIndexFinder(warp_size_).Find(op->body); - LOG(INFO) << warp_index_ << ", " << width_; - LOG(INFO) << op->buffer_var; warp_coeff_ = WarpStoreCoeffFinder(buffer_, warp_index_, analyzer_).Find(op->body); - // if (warp_size_ == width_) { - // warp_coeff_ = 1; - // } else { - // LOG(INFO) << alloc_size << ", " << warp_size_ << ", " << width_; - // warp_coeff_ = WarpStoreCoeffFinder(buffer_, warp_index_, analyzer_).Find(op->body); - // } // Align the local memory size. The number of elements may not // be a multiple of width_ * warp_coeff_; round it up. - LOG(INFO) << warp_coeff_; int factor = width_ * warp_coeff_; ICHECK_NE(factor, 0) << "Divide by zero"; warp_group_ = (alloc_size + (factor - 1)) / factor; @@ -266,23 +253,15 @@ class WarpAccessRewriter : protected StmtExprMutator { } protected: - PrimExpr VisitExpr_(const CallNode* op) override { if (op->op.same_as(builtin::ptx_mma())) { Array new_args = op->args; PrimExpr local_index, group; bool changed = false; - // Arguments here look like [mma_name, Multiplicand A, offset/bias of A, Multiplicand - // B, offset/bias of B, Accumulator C, offset/bias of C]. Therefore 2,4,6 are the indices - // for offset/bias of Multiplicand A, B and Accumulator C for MMA instruction where 1,3,5 - // are the indices for Multiplicand A, B and Accumulator C. - // Because this pass only process one buffer at a time, we need to make sure this function - // only changes the offset of the buffer being processed int A_warp_arg_ind = 6; for (int i = A_warp_arg_ind; i < A_warp_arg_ind + 6; i += 2) { if (op->args[i].get() == buffer_) { std::tie(local_index, group) = SplitIndexByGroup(op->args[i + 1]); - LOG(INFO) << "local index for " << op->args[i] << " : " << local_index; new_args.Set(i + 1, local_index); changed = true; } @@ -290,6 +269,7 @@ class WarpAccessRewriter : protected StmtExprMutator { if (!changed) return GetRef(op); return Call(op->dtype, op->op, new_args); } + if (op->op.same_as(builtin::ptx_ldmatrix())) { Array new_args = op->args; PrimExpr local_index, group; @@ -299,32 +279,11 @@ class WarpAccessRewriter : protected StmtExprMutator { } return GetRef(op); } - if (op->op.same_as(builtin::call_extern()) && - Downcast(op->args[0])->value == "mma_m16n8k8_row_row_fp16fp16fp32") { - Array new_args = op->args; - PrimExpr local_index, group; - bool changed = false; - // Arguments here look like [mma_name, Multiplicand A, offset/bias of A, Multiplicand - // B, offset/bias of B, Accumulator C, offset/bias of C]. Therefore 2,4,6 are the indices - // for offset/bias of Multiplicand A, B and Accumulator C for MMA instruction where 1,3,5 - // are the indices for Multiplicand A, B and Accumulator C. - // Because this pass only process one buffer at a time, we need to make sure this function - // only changes the offset of the buffer being processed - for (int i = 1; i < 6; i += 2) { - if (op->args[i].get() == buffer_) { - std::tie(local_index, group) = SplitIndexByGroup(op->args[i + 1]); - new_args.Set(i + 1, local_index); - changed = true; - } - } - if (!changed) return GetRef(op); - return Call(op->dtype, op->op, new_args); - } + return StmtExprMutator::VisitExpr_(op); } PrimExpr VisitExpr_(const VarNode* op) override { - // LOG(INFO) << GetRef(op); ICHECK(op != buffer_) << "Cannot access address of warp memory directly"; return StmtExprMutator::VisitExpr_(op); } @@ -371,10 +330,6 @@ class WarpAccessRewriter : protected StmtExprMutator { PrimExpr local_index, group; std::tie(local_index, group) = SplitIndexByGroup(op->indices[0]); // invariance: local index must do not contain warp id - LOG(INFO) << "local_index = " << local_index; - LOG(INFO) << "group = " << group; - LOG(INFO) << "warp_index_ = " << warp_index_; - ICHECK(!UsesVar(local_index, [this](const VarNode* var) { return var == warp_index_.get(); })) << "LowerWarpMemory failed to rewrite load to shuffle for index " << op->indices[0] << " local_index=" << local_index; @@ -511,13 +466,11 @@ Pass LowerWarpMemory() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); auto target = f->GetAttr(tvm::attr::kTarget); - LOG(INFO) << f; ICHECK(target.defined()) << "LowerWarpMemory: Require the target attribute"; - int warp_size = 32; + int warp_size = target.value()->GetAttr("thread_warp_size", 1).value(); WarpMemoryRewriter warp_memory_rewriter(warp_size); auto stmt = warp_memory_rewriter.Rewrite(std::move(n->body)); n->body = UpdatePointerStorageScope(warp_memory_rewriter.new_storage_scopes_)(stmt); - LOG(INFO) << "After lower warp: " << f; return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {});