From d0c3fc93c1e11890479a67d64a8aa6868383bbeb Mon Sep 17 00:00:00 2001 From: shmsong Date: Wed, 21 Sep 2022 15:45:54 -0700 Subject: [PATCH 01/10] pipe through cpasyncCG --- torch/csrc/jit/codegen/cuda/codegen.cpp | 22 +++-- torch/csrc/jit/codegen/cuda/lower_utils.cpp | 3 +- .../jit/codegen/cuda/lower_validation.cpp | 1 + torch/csrc/jit/codegen/cuda/runtime/memory.cu | 92 +++++++++++++++++++ .../jit/codegen/cuda/scheduler/matmul.cpp | 2 +- torch/csrc/jit/codegen/cuda/type.cpp | 2 + torch/csrc/jit/codegen/cuda/type.h | 2 +- 7 files changed, 114 insertions(+), 10 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index 4dd7f07c40e70..0f88ab94f4462 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -617,18 +617,25 @@ class CudaKernelGenerator : private OptOutConstDispatch { // Utility function to emit a cp.async intrinsic void genCpAsync(const LoadStoreOp* ldst, int vec_size) { auto dtype = ldst->in()->getDataType().value(); + bool is_cg = ldst->opType() == LoadStoreOpType::CpAsyncCg; + + if (is_cg) { + indent() << "Ampere::cpAsyncCg"; + } else { + indent() << "Ampere::cpAsync"; + } if (ldst->predicate() == nullptr) { // Out of line predicate variant - indent() << "Ampere::cpAsync<" << dtype << "," << vec_size << ">(" - << genMaybeHoistedPointer(ldst->out()) << "," - << genMaybeHoistedPointer(ldst->in()) << ");\n"; + code_ << "<" << dtype << "," << vec_size << ">(" + << genMaybeHoistedPointer(ldst->out()) << "," + << genMaybeHoistedPointer(ldst->in()) << ");\n"; } else { // Inline predicate variant - indent() << "Ampere::cpAsync<" << dtype << "," << vec_size << ">(" - << genMaybeHoistedPointer(ldst->out()) << "," - << genMaybeHoistedPointer(ldst->in()) << "," - << genInline(ldst->predicate()) << ");\n"; + code_ << "<" << dtype << "," << vec_size << ">(" + << genMaybeHoistedPointer(ldst->out()) << "," + << genMaybeHoistedPointer(ldst->in()) << "," + << genInline(ldst->predicate()) << ");\n"; } } @@ -1402,6 +1409,7 @@ class CudaKernelGenerator : private OptOutConstDispatch { genLdMatrix(ldst, vector_word_size); break; case LoadStoreOpType::CpAsync: + case LoadStoreOpType::CpAsyncCg: genCpAsync(ldst, vector_word_size); break; default: diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp index 0bb82bfbee6b0..cb2b55a4e984b 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp @@ -142,7 +142,8 @@ bool isLdMatrixOp(const Expr* expr) { bool isCpAsyncOp(const Expr* expr) { if (auto ldst = dynamic_cast(expr)) { - return ldst->opType() == LoadStoreOpType::CpAsync; + return ldst->opType() == LoadStoreOpType::CpAsync || + ldst->opType() == LoadStoreOpType::CpAsyncCg; } return false; } diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp index 38c12aa2d3f0a..f23ebef61c181 100644 --- a/torch/csrc/jit/codegen/cuda/lower_validation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp @@ -1016,6 +1016,7 @@ void validateArchMemoryOp(LoadStoreOp* ldst) { validateLdMatrixOutput(ldst->out()->as()); return; case LoadStoreOpType::CpAsync: + case LoadStoreOpType::CpAsyncCg: validateMinimumArch(8, 0); return; default: diff --git a/torch/csrc/jit/codegen/cuda/runtime/memory.cu b/torch/csrc/jit/codegen/cuda/runtime/memory.cu index 38bacb3d6a1fc..b5c0c7372305d 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/memory.cu +++ 
b/torch/csrc/jit/codegen/cuda/runtime/memory.cu @@ -271,6 +271,98 @@ DEVICE_INLINE void cpAsync( "r"((int)predicate)); } +// Global to SMEM load that is asynchronous, +// The cache global variant, i.e. skip L1 caching. +// more details see: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cache-operators +// not guaranteed to be completed until cpAsyncBarrier() is called. +template +DEVICE_INLINE void cpAsyncCg(void* smem_ptr, void const* gmem_ptr) { + unsigned smem_addr = util::toSmem(smem_ptr); + constexpr int byte_size = sizeof(dtype) * len; + + static_assert( + byte_size == 4 || byte_size == 8 || byte_size == 16, + "cp_async : unsupported byte size"); + + asm volatile( + "cp.async.cg.shared.global [%0], [%1], %2;\n" ::"r"(smem_addr), + "l"(gmem_ptr), + "n"(byte_size)); +} + +// Global to SMEM load that is asynchronous, +// not guaranteed to be completed until cpAsyncBarrier() is called. +template +DEVICE_INLINE void cpAsyncCg( + void* smem_ptr, + void const* gmem_ptr, + bool predicate) { + unsigned smem_addr = util::toSmem(smem_ptr); + constexpr int byte_size = sizeof(dtype) * len; + + static_assert( + byte_size == 4 || byte_size == 8 || byte_size == 16, + "cp_async : unsupported byte size"); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %3, 0;\n" + "@p cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem_addr), + "l"(gmem_ptr), + "n"(byte_size), + "r"((int)predicate)); +} + +// cp.async +// This is the variant that supports lifted indexing +template +DEVICE_INLINE void cpAsyncCg( + nvfuser_index_t smem_index, + unsigned smem_addr, + nvfuser_index_t gmem_index, + DataPointer& gmem_ptr) { + constexpr int byte_size = sizeof(dtype) * len; + + static_assert( + byte_size == 4 || byte_size == 8 || byte_size == 16, + "cp_async : unsupported byte size"); + + asm volatile( + "cp.async.cg.shared.global [%0], [%1], %2;\n" ::"r"( + smem_addr + (unsigned)smem_index), + "l"(gmem_ptr + gmem_index), + "n"(byte_size)); +} + +// cp.async +// This is the variant that supports lifted indexing, with predicate inlined. +template +DEVICE_INLINE void cpAsyncCg( + nvfuser_index_t smem_index, + unsigned smem_addr, + nvfuser_index_t gmem_index, + DataPointer& gmem_ptr, + bool predicate) { + constexpr int byte_size = sizeof(dtype) * len; + + static_assert( + byte_size == 4 || byte_size == 8 || byte_size == 16, + "cp_async : unsupported byte size"); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %3, 0;\n" + "@p cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem_addr + (unsigned)smem_index), + "l"(gmem_ptr + gmem_index), + "n"(byte_size), + "r"((int)predicate)); +} + // TODO: Might have a different category of sync if we want to build out this: DEVICE_INLINE void cpAsyncBarrier() { asm volatile("cp.async.wait_all;"); diff --git a/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp b/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp index 5c6f9d9eb33f7..7435f91bedd83 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp @@ -370,7 +370,7 @@ void scheduleMatmul( // Use cp.async as requested in scheduler params. 
c10::optional load_op = c10::nullopt; if (params.async_gmem_load_operands) { - load_op = LoadStoreOpType::CpAsync; + load_op = LoadStoreOpType::CpAsyncCg; } acw_smem = ar->cacheAfter(load_op); diff --git a/torch/csrc/jit/codegen/cuda/type.cpp b/torch/csrc/jit/codegen/cuda/type.cpp index 339b9825adaa2..8840a59193da8 100644 --- a/torch/csrc/jit/codegen/cuda/type.cpp +++ b/torch/csrc/jit/codegen/cuda/type.cpp @@ -808,6 +808,8 @@ static const char* load_store_type2string(LoadStoreOpType t) { return "LdMatrixTranspose"; case LoadStoreOpType::CpAsync: return "CpAsync"; + case LoadStoreOpType::CpAsyncCg: + return "CpAsyncCg"; default: TORCH_INTERNAL_ASSERT(false, "Unexpected parallel type"); } diff --git a/torch/csrc/jit/codegen/cuda/type.h b/torch/csrc/jit/codegen/cuda/type.h index 29d689559a769..c1d2d8c6e088c 100644 --- a/torch/csrc/jit/codegen/cuda/type.h +++ b/torch/csrc/jit/codegen/cuda/type.h @@ -323,7 +323,7 @@ static constexpr std::array kIdMappingModes = { // Used to annotate the special memory intrinsics that a loadstore // op will be lowered to. -enum class LoadStoreOpType { LdMatrix, LdMatrixTranspose, CpAsync }; +enum class LoadStoreOpType { LdMatrix, LdMatrixTranspose, CpAsync, CpAsyncCg }; // Used to label what part of the double buffered iterdomain // a for loop is materializing. From 660874427c796c8edebfafe174657f7a100f3942 Mon Sep 17 00:00:00 2001 From: shmsong Date: Thu, 8 Sep 2022 18:45:01 -0700 Subject: [PATCH 02/10] make shared mem alloc in place; tighten alias analysis --- .../jit/codegen/cuda/lower_alias_memory.cpp | 14 +++++++++++++- .../csrc/jit/codegen/cuda/lower_allocation.cpp | 17 +++++------------ 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp b/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp index 42102c143097e..03fc21fdddd3e 100644 --- a/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp @@ -706,7 +706,19 @@ class BufferUseDefInfo { for (const auto idx : c10::irange(current_stack_.size() - 1)) { if (current_stack_[idx] == allocate_loop_info) { - return current_stack_[idx + 1]; + auto return_candidate = idx + 1; + while (return_candidate < current_stack_.size()) { + if (!current_stack_[return_candidate] + ->loop->iter_domain() + ->isThread()) { + return current_stack_[return_candidate]; + } + return_candidate++; + } + + // This means there are only thread loops between allocate + // stack and the current stack. + return nullptr; } } diff --git a/torch/csrc/jit/codegen/cuda/lower_allocation.cpp b/torch/csrc/jit/codegen/cuda/lower_allocation.cpp index 67f0b7cf30d7b..afa2e76a9c12b 100644 --- a/torch/csrc/jit/codegen/cuda/lower_allocation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_allocation.cpp @@ -510,18 +510,11 @@ class AllocationInserter : public kir::ExprMutator { // Register allocations before initializations to keep them in the right // order if (alloc_expr != nullptr) { - if (allocation.buffer->getMemoryType() == MemoryType::Shared) { - // Shared allocations go at the begining of scope - TORCH_INTERNAL_ASSERT(!exprs_.empty()); - registerInsertBefore(exprs_[0], alloc_expr, nullptr); - } else { - TORCH_INTERNAL_ASSERT(allocation.alloc_place_before != nullptr); - kir::Scope* scope = allocation.alloc_for_loop == nullptr - ? 
nullptr - : &allocation.alloc_for_loop->body(); - registerInsertBefore( - allocation.alloc_place_before, alloc_expr, scope); - } + TORCH_INTERNAL_ASSERT(allocation.alloc_place_before != nullptr); + kir::Scope* scope = allocation.alloc_for_loop == nullptr + ? nullptr + : &allocation.alloc_for_loop->body(); + registerInsertBefore(allocation.alloc_place_before, alloc_expr, scope); } if (init_expr != nullptr) { From 9c9b4174bfa77d772edb89a3f3f3c967b74740c3 Mon Sep 17 00:00:00 2001 From: shmsong Date: Thu, 8 Sep 2022 21:07:21 -0700 Subject: [PATCH 03/10] add deallocate node --- torch/csrc/jit/codegen/cuda/codegen.cpp | 12 ++++++++++++ torch/csrc/jit/codegen/cuda/dispatch.cpp | 15 +++++++++++++++ torch/csrc/jit/codegen/cuda/dispatch.h | 4 ++++ torch/csrc/jit/codegen/cuda/ir_iostream.cpp | 6 ++++++ torch/csrc/jit/codegen/cuda/ir_iostream.h | 1 + torch/csrc/jit/codegen/cuda/kernel.h | 14 ++++++++++++++ torch/csrc/jit/codegen/cuda/kernel_ir.cpp | 13 +++++++++++++ torch/csrc/jit/codegen/cuda/kernel_ir.h | 13 +++++++++++++ torch/csrc/jit/codegen/cuda/lower_index.cpp | 5 +++++ torch/csrc/jit/codegen/cuda/lower_index.h | 1 + torch/csrc/jit/codegen/cuda/mutator.cpp | 3 +++ torch/csrc/jit/codegen/cuda/type.h | 1 + 12 files changed, 88 insertions(+) diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index 0f88ab94f4462..90c16233bd386 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -2626,6 +2626,18 @@ class CudaKernelGenerator : private OptOutConstDispatch { indent() << sync_call << ";\n"; } + void handle(const kir::Deallocate* dealloc) final { + const auto alloc = dealloc->buffer(); + const auto tv = alloc->buffer()->as(); + const auto size = alloc->size(); + const auto buffer_dtype = alloc->buffer()->dtype(); + + TORCH_INTERNAL_ASSERT(size != nullptr); + indent() << "// de-alloc " << varName(tv) << "\n"; + indent() << "offset -= (" << genInline(size) << " * sizeof(" << buffer_dtype + << "));\n"; + } + void handle(const kir::InitMagicZero*) final { indent() << "NVFUSER_DEFINE_MAGIC_ZERO\n"; } diff --git a/torch/csrc/jit/codegen/cuda/dispatch.cpp b/torch/csrc/jit/codegen/cuda/dispatch.cpp index 766987638a69c..fdf6f8ea31a90 100644 --- a/torch/csrc/jit/codegen/cuda/dispatch.cpp +++ b/torch/csrc/jit/codegen/cuda/dispatch.cpp @@ -154,6 +154,9 @@ void Expr::dispatch(T handler, Expr* expr) { case ExprType::Allocate: ptr(handler)->handle(expr->as()); return; + case ExprType::DeAllocate: + ptr(handler)->handle(expr->as()); + return; case ExprType::BlockSync: ptr(handler)->handle(expr->as()); return; @@ -331,6 +334,9 @@ void Expr::constDispatch(T handler, const Expr* expr) { case ExprType::Allocate: ptr(handler)->handle(expr->as()); return; + case ExprType::DeAllocate: + ptr(handler)->handle(expr->as()); + return; case ExprType::BlockSync: ptr(handler)->handle(expr->as()); return; @@ -516,6 +522,9 @@ void Expr::mutatorDispatch(T mutator, Expr* expr) { case ExprType::Allocate: ptr(mutator)->mutate(expr->as()); return; + case ExprType::DeAllocate: + ptr(mutator)->mutate(expr->as()); + return; case ExprType::BlockSync: ptr(mutator)->mutate(expr->as()); return; @@ -766,6 +775,9 @@ void OptOutConstDispatch::handle(const ViewOp* stmt) { void OptOutConstDispatch::handle(const kir::Allocate* stmt) { unhandled(stmt); } +void OptOutConstDispatch::handle(const kir::Deallocate* stmt) { + unhandled(stmt); +} void OptOutConstDispatch::handle(const kir::BlockSync* stmt) { unhandled(stmt); } @@ -913,6 +925,9 @@ void 
OptOutDispatch::handle(ViewOp* stmt) { void OptOutDispatch::handle(kir::Allocate* stmt) { unhandled(stmt); } +void OptOutDispatch::handle(kir::Deallocate* stmt) { + unhandled(stmt); +} void OptOutDispatch::handle(kir::BlockSync* stmt) { unhandled(stmt); } diff --git a/torch/csrc/jit/codegen/cuda/dispatch.h b/torch/csrc/jit/codegen/cuda/dispatch.h index 828bc8affa8dc..170f9d4093762 100644 --- a/torch/csrc/jit/codegen/cuda/dispatch.h +++ b/torch/csrc/jit/codegen/cuda/dispatch.h @@ -95,6 +95,7 @@ class TensorIndex; class IntPair; class Allocate; +class Deallocate; class BlockSync; class GridSync; class CpAsyncWait; @@ -162,6 +163,7 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase { virtual void handle(const ViewOp* stmt); virtual void handle(const kir::Allocate*); + virtual void handle(const kir::Deallocate*); virtual void handle(const kir::BlockSync*); virtual void handle(const kir::GridSync*); virtual void handle(const kir::CpAsyncWait*); @@ -226,6 +228,7 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase { virtual void handle(ViewOp* stmt); virtual void handle(kir::Allocate* stmt); + virtual void handle(kir::Deallocate* stmt); virtual void handle(kir::BlockSync* stmt); virtual void handle(kir::GridSync* stmt); virtual void handle(kir::CpAsyncWait* stmt); @@ -331,6 +334,7 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase { virtual void mutate(ViewOp*); virtual void mutate(kir::Allocate*); + virtual void mutate(kir::Deallocate*); virtual void mutate(kir::BlockSync*); virtual void mutate(kir::GridSync*); virtual void mutate(kir::CpAsyncWait*); diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp index c81b9c7ab8130..44253677c78b4 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp @@ -629,6 +629,12 @@ void IrPrinter::handle(const kir::GridSync* node) { os_ << ")\n"; } +void IrPrinter::handle(const kir::Deallocate* node) { + indent() << "DeAllocate("; + handle(node->buffer()); + os_ << ")\n"; +} + void IrPrinter::handle(const kir::ForLoop* node) { indent() << "FOR "; handle(node->index()); diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.h b/torch/csrc/jit/codegen/cuda/ir_iostream.h index ceadbe398c255..1a287a0c21f7d 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.h +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.h @@ -109,6 +109,7 @@ class TORCH_CUDA_CU_API IrPrinter : public OptInConstDispatch { void handle(const kir::ForLoop*) final; void handle(const kir::IfThenElse*) final; void handle(const kir::Allocate*) final; + void handle(const kir::Deallocate*) final; void handle(const kir::BlockSync*) final; void handle(const kir::GridSync*) final; void handle(const kir::CpAsyncWait*) final; diff --git a/torch/csrc/jit/codegen/cuda/kernel.h b/torch/csrc/jit/codegen/cuda/kernel.h index e2a0e57ed68f6..3958a8eae1978 100644 --- a/torch/csrc/jit/codegen/cuda/kernel.h +++ b/torch/csrc/jit/codegen/cuda/kernel.h @@ -22,6 +22,17 @@ namespace fuser { namespace cuda { namespace kir { +// Smem allocation-deallocation trace: +enum class SmemAllocAction { Allocate = 0, DeAllocate }; + +struct AllocateRecord { + // Vector of allocate or deallocate actions in a kernel + std::vector actions; + + // Vector of sizes in corresponding allocations. + std::vector sizes; +}; + //! Summary of interesting facts about the kernel // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) struct KernelSummary { @@ -96,6 +107,9 @@ struct KernelSummary { //! 
Track information on vectorized set operations for runtime validation std::vector vectorized_set_info; + + //! Records allocation and deallocation of dynamic shared memory space. + AllocateRecord allocation_record; }; class TORCH_CUDA_CU_API KernelPerformanceProfile { diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp index e78e0681b6a77..ec085b8fcb35f 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp @@ -592,6 +592,19 @@ Allocate::Allocate( "IR type only valid for Kernel container."); } +Deallocate::Deallocate(IrBuilderPasskey passkey, Allocate* buffer) + : Expr(passkey, ExprType::DeAllocate), buffer_(buffer) { + TORCH_INTERNAL_ASSERT( + buffer->buffer()->isA() && + // Only shared mem is supported for now. + buffer->buffer()->as()->getMemoryType() == + MemoryType::Shared); + + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); +} + GridReduction::GridReduction( IrBuilderPasskey passkey, BinaryOpType reduction_op_type, diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index 32c4d279b2cc1..7d8bdd1ab0612 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -270,6 +270,19 @@ class TORCH_CUDA_CU_API Allocate final : public Expr { const Allocate* alias_ = nullptr; }; +//! Deallocate a space that has been occupied by a buffer. +class TORCH_CUDA_CU_API Deallocate final : public Expr { + public: + explicit Deallocate(IrBuilderPasskey passkey, Allocate* buffer); + + auto buffer() const { + return buffer_; + } + + private: + const Allocate* buffer_ = nullptr; +}; + // Sync represents __syncthreads barrier for block level coordination. // // TODO(kir): change name to SyncThreads as we could have other barriers. 
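For orientation between the IR and codegen halves of this patch: kir::Deallocate only records which kir::Allocate it retires, and the codegen handler added above turns it into a decrement of the running dynamic shared memory offset. A minimal sketch of the emitted pattern, where the names kernel_sketch, array, offset, T5 and T7 and the allocation side are illustrative assumptions rather than literal codegen output:

    __global__ void kernel_sketch() {
      extern __shared__ char array[];
      unsigned offset = 0;
      // alloc T5 (the allocation side is assumed; it is emitted elsewhere in codegen)
      float* T5 = reinterpret_cast<float*>(array + offset);
      offset += (128 * sizeof(float));
      T5[threadIdx.x] = 0.f; // stand-in for the expressions that use T5
      // de-alloc T5 (emitted by the new handle(const kir::Deallocate*))
      offset -= (128 * sizeof(float));
      // a later buffer can now reuse the freed space at the same offset
      float* T7 = reinterpret_cast<float*>(array + offset);
      T7[threadIdx.x] = 1.f;
    }
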
diff --git a/torch/csrc/jit/codegen/cuda/lower_index.cpp b/torch/csrc/jit/codegen/cuda/lower_index.cpp index e9645a9249fec..6f0040fc3ceb9 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_index.cpp @@ -873,6 +873,11 @@ void IndexLowering::handle(const kir::GridSync* sync) { pushBack(const_cast(sync)); // NOLINT } +void IndexLowering::handle(const kir::Deallocate* sync) { + // TODO(kir): remove the need for const_cast + pushBack(const_cast(sync)); // NOLINT +} + void IndexLowering::handle(const kir::CpAsyncWait* wait) { // TODO(kir): remove the need for const_cast pushBack(const_cast(wait)); // NOLINT diff --git a/torch/csrc/jit/codegen/cuda/lower_index.h b/torch/csrc/jit/codegen/cuda/lower_index.h index 593f416005a61..bb743e730579b 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.h +++ b/torch/csrc/jit/codegen/cuda/lower_index.h @@ -55,6 +55,7 @@ class TORCH_CUDA_CU_API IndexLowering : private OptOutConstDispatch { void handle(const kir::ForLoop*) final; void handle(const kir::IfThenElse*) final; void handle(const kir::Allocate*) final; + void handle(const kir::Deallocate*) final; void handle(const kir::BlockSync*) final; void handle(const kir::GridSync*) final; void handle(const kir::CpAsyncWait*) final; diff --git a/torch/csrc/jit/codegen/cuda/mutator.cpp b/torch/csrc/jit/codegen/cuda/mutator.cpp index 05fcb5aa9e1a0..91179cb107996 100644 --- a/torch/csrc/jit/codegen/cuda/mutator.cpp +++ b/torch/csrc/jit/codegen/cuda/mutator.cpp @@ -468,6 +468,9 @@ void OptOutMutator::mutate(Swizzle2D* m) { void OptOutMutator::mutate(kir::Allocate*) { TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); } +void OptOutMutator::mutate(kir::Deallocate*) { + TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); +} void OptOutMutator::mutate(kir::BlockSync*) { TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); } diff --git a/torch/csrc/jit/codegen/cuda/type.h b/torch/csrc/jit/codegen/cuda/type.h index c1d2d8c6e088c..136dcfa2bd538 100644 --- a/torch/csrc/jit/codegen/cuda/type.h +++ b/torch/csrc/jit/codegen/cuda/type.h @@ -127,6 +127,7 @@ enum class ExprType { Swizzle2DInt, PairSelect, Allocate, + DeAllocate, BlockSync, GridSync, CpAsyncWait, From 00dc1d6499712b445f086eebea3fcd482b1494c8 Mon Sep 17 00:00:00 2001 From: shmsong Date: Thu, 8 Sep 2022 22:13:13 -0700 Subject: [PATCH 04/10] add deallocation pass --- .../jit/codegen/cuda/lower_alias_memory.cpp | 115 +++++++++++++++++- 1 file changed, 109 insertions(+), 6 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp b/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp index 03fc21fdddd3e..24ba83ece4702 100644 --- a/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp @@ -864,9 +864,8 @@ void BufferReuseDebugPrinter::printAllocInfo(const kir::Allocate* alloc) { //! 
Reuse Allocation nodes via pointer aliasing class AllocateReuseModifier { public: - static void modify(const std::vector& exprs) { - AllocateReuseModifier modifier(exprs); - } + explicit AllocateReuseModifier(const std::vector& exprs) + : AllocateReuseModifier(exprs, nullptr) {} static void debugPrint(const std::vector& exprs) { BufferReuseDebugPrinter debug_printer; @@ -874,10 +873,14 @@ class AllocateReuseModifier { std::cout << debug_printer.dumpDebugInfo(); } + const auto& bufferInfo() const { + return buffer_info_; + } + private: AllocateReuseModifier( const std::vector& exprs, - BufferReuseDebugPrinter* debug_printer_ = nullptr) + BufferReuseDebugPrinter* debug_printer_) : buffer_info_(exprs, debug_printer_) { // Perform in-place sharing first and then outer liveness // based sharing. Since outer liveness info can still @@ -1217,6 +1220,106 @@ class AllocateReuseModifier { bool inner_aliasing_pass_ = true; }; +//! Insert deallocation if enabled to save memory +class InsertDeallocate : kir::ExprMutator { + public: + static std::vector run( + const BufferUseDefInfo& use_def, + std::vector expr) { + InsertDeallocate inserter(use_def, expr); + return inserter.exprs_; + } + + private: + explicit InsertDeallocate( + const BufferUseDefInfo& use_def, + std::vector expr) + : use_def_info_(use_def) { + traverseAndInsert(expr); + } + + using kir::ExprMutator::handle; + + void handle(kir::Allocate* alloc) final { + if ( + // No need to do anything to aliased allocation + alloc->alias() || + // Ignore scalar allocation + !alloc->buffer()->isA() || + // Only deallocate shared memory at the moment. + alloc->buffer()->as()->getMemoryType() != + MemoryType::Shared) { + return; + } + + // Try to garbage collect any expired memory allocation. + bool deallocated = tryDeallocateAt(alloc); + if (deallocated) { + // Insert a sync thread before this allocation + // for buffer reuse safety. + auto block_sync = IrBuilder::create(); + registerInsertAfter(alloc, block_sync); + } + + // put this alloc on the top of stack. + allocate_stack_.push_back(alloc); + } + + bool tryDeallocateAt(kir::Allocate* alloc) { + auto maybe_buffer_info = use_def_info_.getMaybeReuseInfoFor(alloc); + if (!maybe_buffer_info.has_value()) { + // No info on this allocation. Skip. + return false; + } + + auto& buffer_info = maybe_buffer_info.value(); + auto alloc_pos = buffer_info->alloc_pos; + + // status bit keeping track of if deallocation has + // been inserted. + bool deallocated = false; + while (!allocate_stack_.empty()) { + auto top_allocate = allocate_stack_.back(); + auto maybe_top_info = use_def_info_.getMaybeReuseInfoFor(top_allocate); + if (!maybe_top_info.has_value()) { + // Here we are stuck with an smem allocation + // that we have no use-def info for, and in this + // case we just abort any deallocation effort. + return deallocated; + } + + auto& top_info = maybe_top_info.value(); + + // Check outer live interval + if (top_info->outer_live_interval->lastRead() > alloc_pos) { + return deallocated; + } + + // Check aliased live interval + for (auto& outer_live_interval : *(top_info->outer_subscribed_intevals)) { + if (outer_live_interval->lastRead() > alloc_pos) { + return deallocated; + } + } + + // If we pass the above checks, we should be able to deallocate + // the top allocation of the stack. 
+ auto deallocate = + IrBuilder::create(allocate_stack_.back()); + + registerInsertBefore(alloc, deallocate); + allocate_stack_.pop_back(); + deallocated = true; + } + + return deallocated; + } + + private: + const BufferUseDefInfo& use_def_info_; + std::vector allocate_stack_; +}; + } // namespace std::vector reuseMemoryAllocations(const std::vector& exprs) { @@ -1225,8 +1328,8 @@ std::vector reuseMemoryAllocations(const std::vector& exprs) { if (debug_print) { AllocateReuseModifier::debugPrint(exprs); } - AllocateReuseModifier::modify(exprs); - return exprs; + AllocateReuseModifier modifier(exprs); + return InsertDeallocate::run(modifier.bufferInfo(), exprs); } } // namespace cuda From 2c50933f455301ea57ada41dd9e864f7500c32dc Mon Sep 17 00:00:00 2001 From: shmsong Date: Thu, 8 Sep 2022 22:15:51 -0700 Subject: [PATCH 05/10] add fixme --- torch/csrc/jit/codegen/cuda/lower_allocation.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/torch/csrc/jit/codegen/cuda/lower_allocation.cpp b/torch/csrc/jit/codegen/cuda/lower_allocation.cpp index afa2e76a9c12b..3ce0c985e1787 100644 --- a/torch/csrc/jit/codegen/cuda/lower_allocation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_allocation.cpp @@ -510,6 +510,13 @@ class AllocationInserter : public kir::ExprMutator { // Register allocations before initializations to keep them in the right // order if (alloc_expr != nullptr) { + // TODO && FIXME: + // Shared memory allocations are now done in place, and they'd + // need to be lifted to the outermost serial loop so we don't get + // multiple increment. + // TODO: + // even better, should we just allocate and precompute offset + // for each smem tv instead of inplace incrementing? TORCH_INTERNAL_ASSERT(allocation.alloc_place_before != nullptr); kir::Scope* scope = allocation.alloc_for_loop == nullptr ? nullptr From 48166a985e2c8d54da45710e88d1eb107cb1f8c7 Mon Sep 17 00:00:00 2001 From: shmsong Date: Thu, 8 Sep 2022 22:46:05 -0700 Subject: [PATCH 06/10] pipe through deallocation to executor --- torch/csrc/jit/codegen/cuda/executor.cpp | 45 +++++++++++++++++-- torch/csrc/jit/codegen/cuda/executor.h | 6 +++ torch/csrc/jit/codegen/cuda/kernel.cpp | 12 +++++ torch/csrc/jit/codegen/cuda/kernel.h | 3 ++ .../jit/codegen/cuda/lower_allocation.cpp | 2 +- 5 files changed, 64 insertions(+), 4 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/executor.cpp b/torch/csrc/jit/codegen/cuda/executor.cpp index 5f8a69dd22516..f1d49276a439d 100644 --- a/torch/csrc/jit/codegen/cuda/executor.cpp +++ b/torch/csrc/jit/codegen/cuda/executor.cpp @@ -442,6 +442,46 @@ uint64_t FusionExecutor::computeSharedMemory( return total; } +uint64_t FusionExecutor::computeDynamicSharedMemory( + kir::ExpressionEvaluator& expr_eval, + const kir::AllocateRecord& records, + uint64_t total) { + FUSER_PERF_SCOPE("computeDynamicSharedMemory"); + uint64_t max_total = total; + int N = records.actions.size(); + + for (int idx : c10::irange(N)) { + auto size = records.sizes[idx]; + auto alloc = records.allocs[idx]; + + // If this buffer aliases another buffer, + // then do not allocate memory for this buffer. 
+ const auto inferred_val = expr_eval.evaluate(size); + TORCH_INTERNAL_ASSERT( + inferred_val.has_value(), + "Failed to evaluate the size ", + size, + " of shared memory buffer ", + alloc->buffer()->toString()); + + const uint64_t data_size = dataTypeSize(alloc->buffer()->dtype()); + const uint64_t buffer_size = data_size * inferred_val.value(); + + if (records.actions[idx] == kir::SmemAllocAction::Allocate) { + // Allocate: + const int align_size = 16; // always align to 16B/128b. + total = ceilDiv(total, align_size) * align_size; + total += buffer_size; + max_total = std::max(total, max_total); + } else { + // Deallocate: + total -= buffer_size; + } + } + + return max_total; +} + LaunchParams FusionExecutor::computeLaunchParams( const LaunchParams& launch_constraints, kir::ExpressionEvaluator& expr_eval, @@ -615,10 +655,9 @@ LaunchParams FusionExecutor::computeLaunchParams( } // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - const uint64_t dynamic_smem_size = computeSharedMemory( + const uint64_t dynamic_smem_size = computeDynamicSharedMemory( expr_eval, - kernel_summary.dynamic_smem_allocations, - true, + kernel_summary.allocation_record, reduction_broadcast_workspace); // Check that requested smem size can be dynamically allocated. diff --git a/torch/csrc/jit/codegen/cuda/executor.h b/torch/csrc/jit/codegen/cuda/executor.h index 8a56fe957fb8b..482b6c6d3e3b3 100644 --- a/torch/csrc/jit/codegen/cuda/executor.h +++ b/torch/csrc/jit/codegen/cuda/executor.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -179,6 +180,11 @@ class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable { bool align_padding = false, uint64_t total = 0); + uint64_t computeDynamicSharedMemory( + kir::ExpressionEvaluator& expr_eval, + const kir::AllocateRecord& records, + uint64_t total = 0); + // return a pair of vector of tensors, where tensors in the first vector are // not initialized, while the second vector contains zero-initiliazed tensors GlobalBuffers allocGlobalVals(kir::ExpressionEvaluator& expr_eval); diff --git a/torch/csrc/jit/codegen/cuda/kernel.cpp b/torch/csrc/jit/codegen/cuda/kernel.cpp index 3f5efe02d8ed6..76d9f72f04477 100644 --- a/torch/csrc/jit/codegen/cuda/kernel.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel.cpp @@ -70,6 +70,12 @@ class KernelIrScanner : private IrVisitor { break; case MemoryType::Shared: summary_.dynamic_smem_allocations.push_back(allocate); + if (!allocate->alias()) { + summary_.allocation_record.actions.push_back( + SmemAllocAction::Allocate); + summary_.allocation_record.sizes.push_back(allocate->size()); + summary_.allocation_record.allocs.push_back(allocate); + } break; case MemoryType::Local: if (!ExpressionEvaluator::isConst(allocate->size())) { @@ -80,6 +86,12 @@ class KernelIrScanner : private IrVisitor { } } + void handle(kir::Deallocate* deallocate) final { + summary_.allocation_record.actions.push_back(SmemAllocAction::DeAllocate); + summary_.allocation_record.sizes.push_back(deallocate->buffer()->size()); + summary_.allocation_record.allocs.push_back(deallocate->buffer()); + } + void handle(UnaryOp* unary_op) final { if (unary_op->getUnaryOpType() == UnaryOpType::RandLike) { summary_.max_rng_offsets = diff --git a/torch/csrc/jit/codegen/cuda/kernel.h b/torch/csrc/jit/codegen/cuda/kernel.h index 3958a8eae1978..11bef95591d5d 100644 --- a/torch/csrc/jit/codegen/cuda/kernel.h +++ b/torch/csrc/jit/codegen/cuda/kernel.h @@ -31,6 +31,9 @@ struct AllocateRecord { // Vector of sizes in corresponding allocations. 
std::vector sizes; + + // Vector of alloc expressions. + std::vector allocs; }; //! Summary of interesting facts about the kernel diff --git a/torch/csrc/jit/codegen/cuda/lower_allocation.cpp b/torch/csrc/jit/codegen/cuda/lower_allocation.cpp index 3ce0c985e1787..1b3ef98680270 100644 --- a/torch/csrc/jit/codegen/cuda/lower_allocation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_allocation.cpp @@ -516,7 +516,7 @@ class AllocationInserter : public kir::ExprMutator { // multiple increment. // TODO: // even better, should we just allocate and precompute offset - // for each smem tv instead of inplace incrementing? + // for each smem tv instead of inplace incrementing? TORCH_INTERNAL_ASSERT(allocation.alloc_place_before != nullptr); kir::Scope* scope = allocation.alloc_for_loop == nullptr ? nullptr From c015c56ed6435029a1387ff209907542982aea57 Mon Sep 17 00:00:00 2001 From: shmsong Date: Fri, 9 Sep 2022 09:58:33 -0700 Subject: [PATCH 07/10] prototype epilog schedule --- torch/csrc/jit/codegen/cuda/ir_nodes.cpp | 10 +++--- .../jit/codegen/cuda/lower_double_buffer.cpp | 24 +++++++------- torch/csrc/jit/codegen/cuda/lower_utils.cpp | 4 +++ .../jit/codegen/cuda/scheduler/matmul.cpp | 33 +++++++++++++++++-- 4 files changed, 52 insertions(+), 19 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index 55412e1816d97..be2cd10372cdb 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -1402,11 +1402,11 @@ void IterDomain::parallelize(ParallelType t) { // to make copies of the iterdomains. We might eventually just want // to lock these parallel types and not allowing any changes once // they are swizzled. - TORCH_CHECK( - t == ParallelType::Vectorize || t == ParallelType::TIDx || - t == ParallelType::Serial || t == ParallelType::Mma, - "Parallel type other than serial, tidx, vectorize not allowed for mma swizzled ids", - t); + // TORCH_CHECK( + // t == ParallelType::Vectorize || t == ParallelType::TIDx || + // t == ParallelType::Serial || t == ParallelType::Mma, + // "Parallel type other than serial, tidx, vectorize not allowed for mma swizzled ids", + // t); } parallel_type_ = t; diff --git a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp index b107d1ad36f98..38404a8b3d5e1 100644 --- a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp @@ -92,16 +92,16 @@ void validateDoubleBufferedTensor(const TensorView* tv) { // are allowed. const auto p_mem_type = producer->getMemoryType(); const auto c_mem_type = tv->getMemoryType(); - TORCH_INTERNAL_ASSERT( - (p_mem_type == MemoryType::Global && - (c_mem_type == MemoryType::Shared || c_mem_type == MemoryType::Local)) || - (c_mem_type == MemoryType::Local), - "Invalid tensor to double-buffer: ", - tv->toString(), - ". Producer memory type: ", - p_mem_type, - ". Consumer memory type: ", - c_mem_type); + // TORCH_INTERNAL_ASSERT( + // (p_mem_type == MemoryType::Global && + // (c_mem_type == MemoryType::Shared || c_mem_type == MemoryType::Local)) || + // (c_mem_type == MemoryType::Local), + // "Invalid tensor to double-buffer: ", + // tv->toString(), + // ". Producer memory type: ", + // p_mem_type, + // ". 
Consumer memory type: ", + // c_mem_type); return; } @@ -146,7 +146,9 @@ class DoubleBufferFusionInspector : private IterVisitor { bool requireEpilogue(const std::vector& exprs) { return std::any_of(exprs.begin(), exprs.end(), [](const Expr* expr) { return expr->input(0)->as()->getMemoryType() == - MemoryType::Shared; + MemoryType::Shared || + expr->input(0)->as()->getMemoryType() == + MemoryType::Local; }); } diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp index cb2b55a4e984b..89a7da1b50287 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp @@ -509,6 +509,10 @@ BasicAllocInfo getAllocInformation( outer_alloc_found = true; } + if(tv->getMemoryType()==MemoryType::Shared && !fl_id->isThread()){ + outer_alloc_found = true; + } + // Allocation of a double buffered tensor is placed outside its // double buffer axis. if (tv->isDoubleBuffered() || tv->isCircularBuffered()) { diff --git a/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp b/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp index 7435f91bedd83..0e29c752d185e 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp @@ -311,6 +311,9 @@ void scheduleMatmul( // Setup accumulator register. auto cc = c->cacheBefore(); + // Setup output smem buffer + auto c_smem = c->cacheBefore(); + // Get the input to the mma op. auto mma = dynamic_cast(cc->definition()); TORCH_INTERNAL_ASSERT(mma != nullptr); @@ -481,6 +484,7 @@ void scheduleMatmul( // Set memory type: acw_smem->setMemoryType(MemoryType::Shared); bcw_smem->setMemoryType(MemoryType::Shared); + c_smem->setMemoryType(MemoryType::Shared); // Set parallelization: // TODO: this section goes to a separate matmul util, @@ -495,11 +499,11 @@ void scheduleMatmul( bcr->axis(-1)->parallelize(ParallelType::Vectorize); // 0 1 2 3 4 5 6 7 8 9 10 - // [Mo No Ko Kw Mwo Nwo Mw Nw (Mi Ni Ki)] + // [Mo No Ko Kw Mw Nw Mwo Nwo(Mi Ni Ki)] cc->axis(0)->parallelize(ParallelType::BIDx); cc->axis(1)->parallelize(ParallelType::BIDy); - cc->axis(4)->parallelize(ParallelType::TIDz); - cc->axis(5)->parallelize(ParallelType::TIDy); + cc->axis(6)->parallelize(ParallelType::TIDz); + cc->axis(7)->parallelize(ParallelType::TIDy); scheduler_utils::parallelizeAllLike( cc, @@ -538,12 +542,35 @@ void scheduleMatmul( scheduler_utils::BoundedDirectionalTransformPropagator::forward( cc, -1, + {c_smem}, + scheduler_utils::BoundedDirectionalTransformPropagator::Options() + .propagateParallelType() + .propagateToBoundary()); + + // Epilog schedule: + scheduler_utils::BoundedDirectionalTransformPropagator::forward( + c_smem, + 3, {c}, scheduler_utils::BoundedDirectionalTransformPropagator::Options() .propagateParallelType() .propagateToBoundary()); + c_smem->computeAt(c, 3); + c->reorder({{-1,-2}, {-2,-1}}); + // 16 x 128, with half of the warps: + + // Output vectorize by 4: + c->split(-2, 2); + c->split(-1, 4); + + // [8, 2, 32, 4] + c->axis(-3)->parallelize(ParallelType::TIDy); + c->axis(-2)->parallelize(ParallelType::TIDx); c->axis(-1)->parallelize(ParallelType::Vectorize); + c_smem->axis(-1)->parallelize(ParallelType::Vectorize); + c_smem->doubleBuffer(); + if (params.index_lift_options.lift_gmem_read_address) { a->liftReadAddress(); From 811a04186d616a524a34055f6f719f5533618f2d Mon Sep 17 00:00:00 2001 From: shmsong Date: Fri, 9 Sep 2022 11:01:51 -0700 Subject: [PATCH 08/10] minor clean up --- torch/csrc/jit/codegen/cuda/ir_nodes.cpp | 4 +- 
.../jit/codegen/cuda/lower_double_buffer.cpp | 8 +- torch/csrc/jit/codegen/cuda/lower_utils.cpp | 2 +- .../jit/codegen/cuda/scheduler/matmul.cpp | 3 +- .../codegen/cuda/test/test_gpu_tensorcore.cpp | 108 ++++++++++++++++++ 5 files changed, 116 insertions(+), 9 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index be2cd10372cdb..19ce379fdd832 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -1405,8 +1405,8 @@ void IterDomain::parallelize(ParallelType t) { // TORCH_CHECK( // t == ParallelType::Vectorize || t == ParallelType::TIDx || // t == ParallelType::Serial || t == ParallelType::Mma, - // "Parallel type other than serial, tidx, vectorize not allowed for mma swizzled ids", - // t); + // "Parallel type other than serial, tidx, vectorize not allowed for mma + // swizzled ids", t); } parallel_type_ = t; diff --git a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp index 38404a8b3d5e1..20f774ebd1239 100644 --- a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp @@ -94,7 +94,8 @@ void validateDoubleBufferedTensor(const TensorView* tv) { const auto c_mem_type = tv->getMemoryType(); // TORCH_INTERNAL_ASSERT( // (p_mem_type == MemoryType::Global && - // (c_mem_type == MemoryType::Shared || c_mem_type == MemoryType::Local)) || + // (c_mem_type == MemoryType::Shared || c_mem_type == MemoryType::Local)) + // || // (c_mem_type == MemoryType::Local), // "Invalid tensor to double-buffer: ", // tv->toString(), @@ -146,9 +147,8 @@ class DoubleBufferFusionInspector : private IterVisitor { bool requireEpilogue(const std::vector& exprs) { return std::any_of(exprs.begin(), exprs.end(), [](const Expr* expr) { return expr->input(0)->as()->getMemoryType() == - MemoryType::Shared || - expr->input(0)->as()->getMemoryType() == - MemoryType::Local; + MemoryType::Shared || + expr->input(0)->as()->getMemoryType() == MemoryType::Local; }); } diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp index 89a7da1b50287..2c295a4729721 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp @@ -509,7 +509,7 @@ BasicAllocInfo getAllocInformation( outer_alloc_found = true; } - if(tv->getMemoryType()==MemoryType::Shared && !fl_id->isThread()){ + if (tv->getMemoryType() == MemoryType::Shared && !fl_id->isThread()) { outer_alloc_found = true; } diff --git a/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp b/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp index 0e29c752d185e..2ea01f53c8609 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp @@ -557,7 +557,7 @@ void scheduleMatmul( .propagateToBoundary()); c_smem->computeAt(c, 3); - c->reorder({{-1,-2}, {-2,-1}}); + c->reorder({{-1, -2}, {-2, -1}}); // 16 x 128, with half of the warps: // Output vectorize by 4: @@ -571,7 +571,6 @@ void scheduleMatmul( c_smem->axis(-1)->parallelize(ParallelType::Vectorize); c_smem->doubleBuffer(); - if (params.index_lift_options.lift_gmem_read_address) { a->liftReadAddress(); b->liftReadAddress(); diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp index 6f79dbf2f90f4..d34db2db8dedf 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp +++ 
b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp @@ -2826,6 +2826,114 @@ TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoad_CUDA) { } } +TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoadQuickSanity1_CUDA) { + // Keep multiples of 8 to keep vectorizable. + int M = 2048, N = 3456, K = 1024; + for (auto layout : kAllSupportedLayout) { + Fusion fusion; + FusionGuard fg(&fusion); + auto tv0 = makeContigTensor(2, DataType::Half); + auto tv1 = makeContigTensor(2, DataType::Half); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = matmul(tv0, tv1, layout); + + fusion.addOutput(tv2); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 64); + gemm_tile.warp_tile = GemmTile(64, 64, 64); + gemm_tile.instruction_tile = GemmTile(16, 16, 16); + + auto mma_builder = + MmaBuilder(MmaOptions::MacroType::Ampere_16_16_16, gemm_tile) + .layout(layout); + + MatmulParam params(mma_builder); + params.tile_sizes = gemm_tile; + params.async_gmem_load_operands = true; + params.double_buffer_options.double_buffer_smem_write = true; + params.double_buffer_options.double_buffer_smem_read = true; + params.double_buffer_options.smem_double_buffer_stage = 3; + scheduleMatmul(tv2, tv0, tv1, params); + + at::manual_seed(0); + auto inputs = fp16MatmulAtInput(M, N, K, layout); + + CompileOptions co; + co.index_mode = KernelIndexMode::INT32; + + FusionExecutor fe; + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 8, + 0, + fe.compileFusion( + &fusion, {inputs.first, inputs.second}, LaunchParams(), co)); + + // return; + auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); + auto tref = atMatmul( + inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); + TORCH_CHECK(cg_outputs[0].allclose(tref, 0.001, 0.001)); + } +} + +TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoadQuickSanity2_CUDA) { + // Keep multiples of 8 to keep vectorizable. 
+ int M = 2048, N = 3456, K = 1024; + for (auto layout : kAllSupportedLayout) { + Fusion fusion; + FusionGuard fg(&fusion); + auto tv0 = makeContigTensor(2, DataType::Half); + auto tv1 = makeContigTensor(2, DataType::Half); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = matmul(tv0, tv1, layout); + + fusion.addOutput(tv2); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(256, 128, 64); + gemm_tile.warp_tile = GemmTile(64, 64, 64); + gemm_tile.instruction_tile = GemmTile(16, 16, 16); + + auto mma_builder = + MmaBuilder(MmaOptions::MacroType::Ampere_16_16_16, gemm_tile) + .layout(layout); + + MatmulParam params(mma_builder); + params.tile_sizes = gemm_tile; + params.async_gmem_load_operands = true; + params.double_buffer_options.double_buffer_smem_write = true; + params.double_buffer_options.double_buffer_smem_read = true; + params.double_buffer_options.smem_double_buffer_stage = 2; + scheduleMatmul(tv2, tv0, tv1, params); + + at::manual_seed(0); + auto inputs = fp16MatmulAtInput(M, N, K, layout); + + CompileOptions co; + co.index_mode = KernelIndexMode::INT32; + + FusionExecutor fe; + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 8, + 0, + fe.compileFusion( + &fusion, {inputs.first, inputs.second}, LaunchParams(), co)); + + // return; + auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); + auto tref = atMatmul( + inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); + TORCH_CHECK(cg_outputs[0].allclose(tref, 0.001, 0.001)); + } +} + // Tile layout check for symmetric 4-warp recipes TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck4warp_CUDA) { REQUIRE_DEVICE_SMEM_SIZE(98384, 0); From 1e5e745efc4c85d53c908e998566cd0fdaf7a095 Mon Sep 17 00:00:00 2001 From: shmsong Date: Sun, 11 Sep 2022 15:29:57 -0700 Subject: [PATCH 09/10] optionally enable epilog schedule (for now) --- .../jit/codegen/cuda/scheduler/matmul.cpp | 68 +++++++++++-------- .../csrc/jit/codegen/cuda/scheduler/matmul.h | 3 + 2 files changed, 44 insertions(+), 27 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp b/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp index 2ea01f53c8609..3a74f87a68262 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp @@ -311,8 +311,12 @@ void scheduleMatmul( // Setup accumulator register. auto cc = c->cacheBefore(); - // Setup output smem buffer - auto c_smem = c->cacheBefore(); + TensorView* c_smem = nullptr; + + if (params.has_epilog) { + // Setup output smem buffer + c_smem = c->cacheBefore(); + } // Get the input to the mma op. auto mma = dynamic_cast(cc->definition()); @@ -484,7 +488,10 @@ void scheduleMatmul( // Set memory type: acw_smem->setMemoryType(MemoryType::Shared); bcw_smem->setMemoryType(MemoryType::Shared); - c_smem->setMemoryType(MemoryType::Shared); + + if (params.has_epilog) { + c_smem->setMemoryType(MemoryType::Shared); + } // Set parallelization: // TODO: this section goes to a separate matmul util, @@ -539,37 +546,44 @@ void scheduleMatmul( } } + auto output_buffer = params.has_epilog ? 
c_smem : c; + scheduler_utils::BoundedDirectionalTransformPropagator::forward( cc, -1, - {c_smem}, - scheduler_utils::BoundedDirectionalTransformPropagator::Options() - .propagateParallelType() - .propagateToBoundary()); - - // Epilog schedule: - scheduler_utils::BoundedDirectionalTransformPropagator::forward( - c_smem, - 3, - {c}, + {output_buffer}, scheduler_utils::BoundedDirectionalTransformPropagator::Options() .propagateParallelType() .propagateToBoundary()); - c_smem->computeAt(c, 3); - c->reorder({{-1, -2}, {-2, -1}}); - // 16 x 128, with half of the warps: - - // Output vectorize by 4: - c->split(-2, 2); - c->split(-1, 4); - - // [8, 2, 32, 4] - c->axis(-3)->parallelize(ParallelType::TIDy); - c->axis(-2)->parallelize(ParallelType::TIDx); - c->axis(-1)->parallelize(ParallelType::Vectorize); - c_smem->axis(-1)->parallelize(ParallelType::Vectorize); - c_smem->doubleBuffer(); + // Epilog schedule (To be built out): + if (params.has_epilog) { + scheduler_utils::BoundedDirectionalTransformPropagator::forward( + c_smem, + 3, + {c}, + scheduler_utils::BoundedDirectionalTransformPropagator::Options() + .propagateParallelType() + .propagateToBoundary()); + + c_smem->computeAt(c, 3); + c->reorder({{-1, -2}, {-2, -1}}); + // 16 x 128, with half of the warps: + + // Output vectorize by 4: + c->split(-2, 2); + c->split(-1, 4); + + // [8, 2, 32, 4] + c->axis(-3)->parallelize(ParallelType::TIDy); + c->axis(-2)->parallelize(ParallelType::TIDx); + c->axis(-1)->parallelize(ParallelType::Vectorize); + c_smem->axis(-1)->parallelize(ParallelType::Vectorize); + c_smem->doubleBuffer(); + } else { + // Always vector + c->axis(-1)->parallelize(ParallelType::Vectorize); + } if (params.index_lift_options.lift_gmem_read_address) { a->liftReadAddress(); diff --git a/torch/csrc/jit/codegen/cuda/scheduler/matmul.h b/torch/csrc/jit/codegen/cuda/scheduler/matmul.h index 354e2affeab04..e105b2e7e02cf 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/matmul.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/matmul.h @@ -51,6 +51,9 @@ class MatmulParam { //! Enables predicate peeling mainloop: bool peel_main_loop = true; + + //! Enables an epilog schedule + bool has_epilog = false; }; //! Prototype auto scheduling function. From 2e37ff1b69befa8803a0c180d89c7441434531f1 Mon Sep 17 00:00:00 2001 From: shmsong Date: Fri, 23 Sep 2022 18:09:09 -0700 Subject: [PATCH 10/10] rebase fix (as promised :) ). see FusionAmpereMatmulLargeLoad for a working example. 
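Since patch 09 the epilog path is guarded by a scheduler parameter, and this patch flips its default to true. A minimal usage sketch, mirroring the QuickSanity tests added earlier in the series (and removed again here); the fusion, mma_builder and gemm_tile setup are assumed to match those tests:

    MatmulParam params(mma_builder);
    params.tile_sizes = gemm_tile;
    params.async_gmem_load_operands = true;
    params.double_buffer_options.double_buffer_smem_write = true;
    params.double_buffer_options.double_buffer_smem_read = true;
    params.double_buffer_options.smem_double_buffer_stage = 3;
    // Route the mma accumulator through the shared memory epilog
    // before the vectorized global store (the default after this patch).
    params.has_epilog = true;
    scheduleMatmul(tv2, tv0, tv1, params);
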
--- .../jit/codegen/cuda/scheduler/matmul.cpp | 25 +++- .../csrc/jit/codegen/cuda/scheduler/matmul.h | 2 +- .../codegen/cuda/test/test_gpu_tensorcore.cpp | 108 ------------------ 3 files changed, 21 insertions(+), 114 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp b/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp index 3a74f87a68262..b6b9cd8c5fb81 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp @@ -404,6 +404,13 @@ void scheduleMatmul( scheduler_utils::matmul_utils::scheduleWarpTileWithReduction( cc, gemm_tile, true); + // Move the Mw up for epilog: + if (params.has_epilog) { + cc->reorder({{4, 5}, {5, 6}, {6, 4}}); + // 0 1 2 3 4 5 6 7 8 9 10 + // [Mo No Ko Kw Mw Mwo Nwo Nw (Mi Ni Ki)] + } + // Propagate warp tile to main loop and epilog/output tvs scheduler_utils::BoundedDirectionalTransformPropagator::bothWays( cc, -1, {acw_smem, bcw_smem}, {c}); @@ -509,8 +516,15 @@ void scheduleMatmul( // [Mo No Ko Kw Mw Nw Mwo Nwo(Mi Ni Ki)] cc->axis(0)->parallelize(ParallelType::BIDx); cc->axis(1)->parallelize(ParallelType::BIDy); - cc->axis(6)->parallelize(ParallelType::TIDz); - cc->axis(7)->parallelize(ParallelType::TIDy); + + // Maybe just keep one of these two options. + if (params.has_epilog) { + cc->axis(5)->parallelize(ParallelType::TIDz); + cc->axis(6)->parallelize(ParallelType::TIDy); + } else { + cc->axis(4)->parallelize(ParallelType::TIDz); + cc->axis(5)->parallelize(ParallelType::TIDy); + } scheduler_utils::parallelizeAllLike( cc, @@ -560,15 +574,15 @@ void scheduleMatmul( if (params.has_epilog) { scheduler_utils::BoundedDirectionalTransformPropagator::forward( c_smem, - 3, + 4, {c}, scheduler_utils::BoundedDirectionalTransformPropagator::Options() .propagateParallelType() .propagateToBoundary()); - c_smem->computeAt(c, 3); + c_smem->computeAt(c, 4); c->reorder({{-1, -2}, {-2, -1}}); - // 16 x 128, with half of the warps: + // 16 x 128, with 2 warps: // Output vectorize by 4: c->split(-2, 2); @@ -577,6 +591,7 @@ void scheduleMatmul( // [8, 2, 32, 4] c->axis(-3)->parallelize(ParallelType::TIDy); c->axis(-2)->parallelize(ParallelType::TIDx); + c->axis(-1)->parallelize(ParallelType::Vectorize); c_smem->axis(-1)->parallelize(ParallelType::Vectorize); c_smem->doubleBuffer(); diff --git a/torch/csrc/jit/codegen/cuda/scheduler/matmul.h b/torch/csrc/jit/codegen/cuda/scheduler/matmul.h index e105b2e7e02cf..41cbcc3052281 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/matmul.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/matmul.h @@ -53,7 +53,7 @@ class MatmulParam { bool peel_main_loop = true; //! Enables an epilog schedule - bool has_epilog = false; + bool has_epilog = true; }; //! Prototype auto scheduling function. diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp index d34db2db8dedf..6f79dbf2f90f4 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp @@ -2826,114 +2826,6 @@ TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoad_CUDA) { } } -TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoadQuickSanity1_CUDA) { - // Keep multiples of 8 to keep vectorizable. 
- int M = 2048, N = 3456, K = 1024; - for (auto layout : kAllSupportedLayout) { - Fusion fusion; - FusionGuard fg(&fusion); - auto tv0 = makeContigTensor(2, DataType::Half); - auto tv1 = makeContigTensor(2, DataType::Half); - - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = matmul(tv0, tv1, layout); - - fusion.addOutput(tv2); - - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(128, 128, 64); - gemm_tile.warp_tile = GemmTile(64, 64, 64); - gemm_tile.instruction_tile = GemmTile(16, 16, 16); - - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Ampere_16_16_16, gemm_tile) - .layout(layout); - - MatmulParam params(mma_builder); - params.tile_sizes = gemm_tile; - params.async_gmem_load_operands = true; - params.double_buffer_options.double_buffer_smem_write = true; - params.double_buffer_options.double_buffer_smem_read = true; - params.double_buffer_options.smem_double_buffer_stage = 3; - scheduleMatmul(tv2, tv0, tv1, params); - - at::manual_seed(0); - auto inputs = fp16MatmulAtInput(M, N, K, layout); - - CompileOptions co; - co.index_mode = KernelIndexMode::INT32; - - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 8, - 0, - fe.compileFusion( - &fusion, {inputs.first, inputs.second}, LaunchParams(), co)); - - // return; - auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); - auto tref = atMatmul( - inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); - TORCH_CHECK(cg_outputs[0].allclose(tref, 0.001, 0.001)); - } -} - -TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoadQuickSanity2_CUDA) { - // Keep multiples of 8 to keep vectorizable. - int M = 2048, N = 3456, K = 1024; - for (auto layout : kAllSupportedLayout) { - Fusion fusion; - FusionGuard fg(&fusion); - auto tv0 = makeContigTensor(2, DataType::Half); - auto tv1 = makeContigTensor(2, DataType::Half); - - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = matmul(tv0, tv1, layout); - - fusion.addOutput(tv2); - - MatMulTileOptions gemm_tile; - gemm_tile.cta_tile = GemmTile(256, 128, 64); - gemm_tile.warp_tile = GemmTile(64, 64, 64); - gemm_tile.instruction_tile = GemmTile(16, 16, 16); - - auto mma_builder = - MmaBuilder(MmaOptions::MacroType::Ampere_16_16_16, gemm_tile) - .layout(layout); - - MatmulParam params(mma_builder); - params.tile_sizes = gemm_tile; - params.async_gmem_load_operands = true; - params.double_buffer_options.double_buffer_smem_write = true; - params.double_buffer_options.double_buffer_smem_read = true; - params.double_buffer_options.smem_double_buffer_stage = 2; - scheduleMatmul(tv2, tv0, tv1, params); - - at::manual_seed(0); - auto inputs = fp16MatmulAtInput(M, N, K, layout); - - CompileOptions co; - co.index_mode = KernelIndexMode::INT32; - - FusionExecutor fe; - NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 8, - 0, - fe.compileFusion( - &fusion, {inputs.first, inputs.second}, LaunchParams(), co)); - - // return; - auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); - auto tref = atMatmul( - inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); - TORCH_CHECK(cg_outputs[0].allclose(tref, 0.001, 0.001)); - } -} - // Tile layout check for symmetric 4-warp recipes TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck4warp_CUDA) { REQUIRE_DEVICE_SMEM_SIZE(98384, 0);
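
Taken together, the series records a per-kernel allocate/deallocate trace and replays it at launch time in FusionExecutor::computeDynamicSharedMemory, sizing the dynamic shared memory request by its high-water mark rather than by the sum of all buffers. A stand-alone sketch of that watermark logic, assuming sizes are already-evaluated byte counts (the real code evaluates symbolic sizes through the expression evaluator):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    enum class SmemAllocAction { Allocate, DeAllocate };

    // Replay the allocation trace and return the peak dynamic smem usage.
    uint64_t dynamicSmemWatermark(
        const std::vector<SmemAllocAction>& actions,
        const std::vector<uint64_t>& byte_sizes,
        uint64_t total = 0) {
      uint64_t max_total = total;
      for (size_t idx = 0; idx < actions.size(); ++idx) {
        if (actions[idx] == SmemAllocAction::Allocate) {
          constexpr uint64_t align = 16; // always align to 16B/128b
          total = (total + align - 1) / align * align;
          total += byte_sizes[idx];
          max_total = std::max(total, max_total);
        } else {
          total -= byte_sizes[idx];
        }
      }
      return max_total;
    }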