Skip to content

Commit

Permalink
tuned int8 4k, 91 TOPS
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 17, 2022
1 parent 94d9d96 commit c2e314c
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 7 deletions.
3 changes: 1 addition & 2 deletions src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -821,8 +821,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {

if (trans && op->dtype.bits() == 8) {
std::string smem_stride = this->PrintExpr(op->args[6]);
LOG(INFO) << op->dtype;
CHECK(num == 4);
ICHECK(num == 4);
os << "for (int i = 0; i < 16; ++i) {\n";
os << local_ptr << "[" + local_elem_offset + " + i] = " << smem_ptr
<< "[(i % 8) / 4 * " + smem_stride + " * 16 + (threadIdx.x % 4) * 4 * " + smem_stride +
Expand Down
2 changes: 0 additions & 2 deletions src/tir/transforms/lower_warp_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -492,13 +492,11 @@ namespace transform {
Pass LowerWarpMemory() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
// LOG(INFO) << f;
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
int warp_size = 32;
WarpMemoryRewriter warp_memory_rewriter(warp_size);
auto stmt = warp_memory_rewriter.Rewrite(std::move(n->body));
n->body = UpdatePointerStorageScope(warp_memory_rewriter.new_storage_scopes_)(stmt);
LOG(INFO) << f;
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {});
Expand Down
6 changes: 3 additions & 3 deletions tests/python/unittest/test_mma_16x8x32_4k_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,9 +323,9 @@ def schedule(sch: tir.Schedule):
k_factors = sch.sample_perfect_tile(k, n=3)
num_ty = sch.get(i_factors[2]) * sch.get(j_factors[2])
else:
i_factors = [4, 8, 2, 4, 1]
j_factors = [1, 64, 2, 1, 2]
k_factors = [64, 2, 1]
i_factors = [1, 32, 1, 4, 2]
j_factors = [8, 4, 4, 2, 1]
k_factors = [32, 2, 2]

num_ty = i_factors[2] * j_factors[2]

Expand Down

0 comments on commit c2e314c

Please sign in to comment.