Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 17, 2022
1 parent f6aadbf commit 0c859c4
Showing 1 changed file with 5 additions and 52 deletions.
57 changes: 5 additions & 52 deletions src/tir/transforms/lower_warp_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ class WarpStoreCoeffFinder : private StmtExprVisitor {
}

void UpdatePattern(const PrimExpr& index) {
LOG(INFO) << index;
Array<PrimExpr> m = arith::DetectLinearEquation(index, {warp_index_});
ICHECK_EQ(m.size(), 2U)
<< "LowerWarpMemory failed. Could not simplify the store index `" << index
Expand All @@ -166,7 +165,7 @@ class WarpStoreCoeffFinder : private StmtExprVisitor {

if (warp_coeff_ != 0) {
ICHECK_EQ(warp_coeff_, mcoeff_as_int->value)
<< "LowerWarpMemory failed due to two different store coefficient to warp index" << warp_coeff_ << ", "<< mcoeff_as_int->value;
<< "LowerWarpMemory failed due to two different store coefficient to warp index";
} else {
warp_coeff_ = mcoeff_as_int->value;
}
Expand All @@ -183,7 +182,7 @@ class WarpStoreCoeffFinder : private StmtExprVisitor {
};

// Visitor to find the warp index
class WarpIndexFinder : private StmtExprVisitor {
class WarpIndexFinder : private StmtVisitor {
public:
explicit WarpIndexFinder(int warp_size) : warp_size_(warp_size) {}
// find the warp co-efficient and the shuffle width in the statement
Expand Down Expand Up @@ -215,7 +214,6 @@ class WarpIndexFinder : private StmtExprVisitor {
} else {
width_ = value_as_int->value;
warp_index_ = iv;
// LOG(INFO) << "Using warp index from " << GetRef<AttrStmt>(op);
}
}
}
Expand All @@ -236,26 +234,15 @@ class WarpAccessRewriter : protected StmtExprMutator {
// Rewrite the allocate statement which transforms
// warp memory to local memory.
Stmt Rewrite(const AllocateNode* op) {
// LOG(INFO) << "Allocate " << GetRef<Allocate>(op);
// LOG(INFO) << op->buffer_var;
buffer_ = op->buffer_var.get();
int alloc_size = op->ConstantAllocationSize();
ICHECK_GT(alloc_size, 0) << "warp memory only support constant alloc size";
alloc_size *= op->dtype.lanes();
std::tie(warp_index_, width_) = WarpIndexFinder(warp_size_).Find(op->body);
LOG(INFO) << warp_index_ << ", " << width_;
LOG(INFO) << op->buffer_var;
warp_coeff_ = WarpStoreCoeffFinder(buffer_, warp_index_, analyzer_).Find(op->body);
// if (warp_size_ == width_) {
// warp_coeff_ = 1;
// } else {
// LOG(INFO) << alloc_size << ", " << warp_size_ << ", " << width_;
// warp_coeff_ = WarpStoreCoeffFinder(buffer_, warp_index_, analyzer_).Find(op->body);
// }

// Align the local memory size. The number of elements may not
// be a multiple of width_ * warp_coeff_; round it up.
LOG(INFO) << warp_coeff_;
int factor = width_ * warp_coeff_;
ICHECK_NE(factor, 0) << "Divide by zero";
warp_group_ = (alloc_size + (factor - 1)) / factor;
Expand All @@ -266,30 +253,23 @@ class WarpAccessRewriter : protected StmtExprMutator {
}

protected:

PrimExpr VisitExpr_(const CallNode* op) override {
if (op->op.same_as(builtin::ptx_mma())) {
Array<PrimExpr> new_args = op->args;
PrimExpr local_index, group;
bool changed = false;
// Arguments here look like [mma_name, Multiplicand A, offset/bias of A, Multiplicand
// B, offset/bias of B, Accumulator C, offset/bias of C]. Therefore 2,4,6 are the indices
// for offset/bias of Multiplicand A, B and Accumulator C for MMA instruction where 1,3,5
// are the indices for Multiplicand A, B and Accumulator C.
// Because this pass only process one buffer at a time, we need to make sure this function
// only changes the offset of the buffer being processed
int A_warp_arg_ind = 6;
for (int i = A_warp_arg_ind; i < A_warp_arg_ind + 6; i += 2) {
if (op->args[i].get() == buffer_) {
std::tie(local_index, group) = SplitIndexByGroup(op->args[i + 1]);
LOG(INFO) << "local index for " << op->args[i] << " : " << local_index;
new_args.Set(i + 1, local_index);
changed = true;
}
}
if (!changed) return GetRef<PrimExpr>(op);
return Call(op->dtype, op->op, new_args);
}

if (op->op.same_as(builtin::ptx_ldmatrix())) {
Array<PrimExpr> new_args = op->args;
PrimExpr local_index, group;
Expand All @@ -299,32 +279,11 @@ class WarpAccessRewriter : protected StmtExprMutator {
}
return GetRef<PrimExpr>(op);
}
if (op->op.same_as(builtin::call_extern()) &&
Downcast<StringImm>(op->args[0])->value == "mma_m16n8k8_row_row_fp16fp16fp32") {
Array<PrimExpr> new_args = op->args;
PrimExpr local_index, group;
bool changed = false;
// Arguments here look like [mma_name, Multiplicand A, offset/bias of A, Multiplicand
// B, offset/bias of B, Accumulator C, offset/bias of C]. Therefore 2,4,6 are the indices
// for offset/bias of Multiplicand A, B and Accumulator C for MMA instruction where 1,3,5
// are the indices for Multiplicand A, B and Accumulator C.
// Because this pass only process one buffer at a time, we need to make sure this function
// only changes the offset of the buffer being processed
for (int i = 1; i < 6; i += 2) {
if (op->args[i].get() == buffer_) {
std::tie(local_index, group) = SplitIndexByGroup(op->args[i + 1]);
new_args.Set(i + 1, local_index);
changed = true;
}
}
if (!changed) return GetRef<PrimExpr>(op);
return Call(op->dtype, op->op, new_args);
}

return StmtExprMutator::VisitExpr_(op);
}

PrimExpr VisitExpr_(const VarNode* op) override {
// LOG(INFO) << GetRef<Var>(op);
ICHECK(op != buffer_) << "Cannot access address of warp memory directly";
return StmtExprMutator::VisitExpr_(op);
}
Expand Down Expand Up @@ -371,10 +330,6 @@ class WarpAccessRewriter : protected StmtExprMutator {
PrimExpr local_index, group;
std::tie(local_index, group) = SplitIndexByGroup(op->indices[0]);
// invariance: local index must do not contain warp id
LOG(INFO) << "local_index = " << local_index;
LOG(INFO) << "group = " << group;
LOG(INFO) << "warp_index_ = " << warp_index_;

ICHECK(!UsesVar(local_index, [this](const VarNode* var) { return var == warp_index_.get(); }))
<< "LowerWarpMemory failed to rewrite load to shuffle for index " << op->indices[0]
<< " local_index=" << local_index;
Expand Down Expand Up @@ -511,13 +466,11 @@ Pass LowerWarpMemory() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
LOG(INFO) << f;
ICHECK(target.defined()) << "LowerWarpMemory: Require the target attribute";
int warp_size = 32;
int warp_size = target.value()->GetAttr<Integer>("thread_warp_size", 1).value();
WarpMemoryRewriter warp_memory_rewriter(warp_size);
auto stmt = warp_memory_rewriter.Rewrite(std::move(n->body));
n->body = UpdatePointerStorageScope(warp_memory_rewriter.new_storage_scopes_)(stmt);
LOG(INFO) << "After lower warp: " << f;
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {});
Expand Down

0 comments on commit 0c859c4

Please sign in to comment.