From a2dfe40b27cd3f5c04207596f0a1818fbd5e5439 Mon Sep 17 00:00:00 2001
From: "Gao, Xiang"
Date: Thu, 6 Oct 2022 08:12:49 -0700
Subject: [PATCH] Fix vectorize size calculation (#2035)

---
 .../jit/codegen/cuda/scheduler/registry.cpp   |  25 +++
 .../jit/codegen/cuda/scheduler/registry.h     |   4 +
 .../cuda/scheduler/vectorize_helper.cpp       |  30 ++--
 torch/csrc/jit/codegen/cuda/test/test_gpu.cpp | 164 +++++++++++++++++-
 4 files changed, 202 insertions(+), 21 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp b/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp
index 3161aeb10b4991..644a4d9e7ebd8b 100644
--- a/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp
+++ b/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp
@@ -463,6 +463,24 @@ void SchedulerRuntimeInfo::initialize(
     auto fusion_inp = complete_fusion_->inputs()[inp_i];
     auto data_ptr = tensor_arg_abstract->getPointer();
     input_ptrs_[fusion_inp] = (size_t)data_ptr;
+
+    // find and push discontiguous stride
+    auto dtype_size = dataTypeSize(tensor_arg_abstract->getDataType());
+    input_discontig_strides_[fusion_inp] = {};
+    auto dims = tensor_arg_abstract->getRank();
+    auto expected_stride = 1;
+    for (auto dim = dims - 1; dim >= 0; dim--) {
+      auto size = tensor_arg_abstract->getSize(dim);
+      if (size <= 1) {
+        continue;
+      }
+      auto stride = tensor_arg_abstract->getStride(dim);
+      if (stride != expected_stride) {
+        input_discontig_strides_[fusion_inp].push_back(stride * dtype_size);
+        expected_stride = stride;
+      }
+      expected_stride *= size;
+    }
   }
 }
 
@@ -529,6 +547,13 @@ size_t SchedulerRuntimeInfo::getAlignmentSize(TensorView* tv) {
   }
 
   auto alignment_size = SchedulerRuntimeInfo::computeAlignmentSize(ptrOf(tv));
+  auto strides_it = input_discontig_strides_.find(tv);
+  if (strides_it != input_discontig_strides_.end()) {
+    for (auto stride : strides_it->second) {
+      alignment_size = std::min(
+          alignment_size, SchedulerRuntimeInfo::computeAlignmentSize(stride));
+    }
+  }
   alignment_map_[tv] = alignment_size;
   return alignment_size;
 }
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/registry.h b/torch/csrc/jit/codegen/cuda/scheduler/registry.h
index 7ed8474935c011..8b34094476349c 100644
--- a/torch/csrc/jit/codegen/cuda/scheduler/registry.h
+++ b/torch/csrc/jit/codegen/cuda/scheduler/registry.h
@@ -27,6 +27,7 @@ class ExpressionEvaluator;
 //! segmenter and schedulers.
 //! It is important that input id encoding should be up to date with any change
 //! of this class to avoid launching compiled kernels with illegal inputs.
+
 class TORCH_CUDA_CU_API SchedulerRuntimeInfo : public NonCopyable {
  public:
   // Max vector size we will consider, in bytes,
@@ -112,6 +113,9 @@ class TORCH_CUDA_CU_API SchedulerRuntimeInfo : public NonCopyable {
   // TODO: Support output tensor pointers
   std::unordered_map<Val*, size_t> input_ptrs_;
 
+  // Copy of aten input tensor strides (in bytes)
+  std::unordered_map<Val*, std::vector<size_t>> input_discontig_strides_;
+
   // Cache for getAlignmentSize
   std::unordered_map<TensorView*, size_t> alignment_map_;
   // Cache for getMaxVectorizableWidth
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.cpp b/torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.cpp
index 8411a3c112405d..2c3c848c7f5c9c 100644
--- a/torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.cpp
+++ b/torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.cpp
@@ -80,18 +80,6 @@ size_t collectMaxVectorizeSizeWithContigMerge(
     size_t max_vector_size_in_byte,
     ExpressionEvaluator& expression_evaluator,
     DataType index_type) {
-  // Maybe too conservative, but only handles fully contiguous tensors
-  // TODO: Relax the contiguity constraint to be similar to that in index
-  // computing. Just looking for all merged root domains in the right order,
-  // all merged root dimensions are contiguous, all merged root dimensions are
-  // next to eachother (exlcuding broadcast).
-  if (std::any_of(
-          tv->domain()->contiguity().begin(),
-          tv->domain()->contiguity().end(),
-          [](const auto contig) { return !contig; })) {
-    return 1;
-  }
-
   auto dtype_size = dataTypeSize(tv->dtype(), index_type);
   const size_t max_vector_size = max_vector_size_in_byte / dtype_size;
 
@@ -202,8 +190,16 @@ size_t expandVectorizationToContigMergedDomains(
 
   // Merge the domains right of the break point
   const auto& ref_root = reference_tv->getMaybeRFactorDomain();
-  const int num_merged_domains =
+  const int max_num_merged_domains =
       static_cast<int>(ref_root.size()) - static_cast<int>(break_point);
+  int64_t num_merged_domains = 0;
+  while (num_merged_domains < max_num_merged_domains) {
+    auto pos = (int64_t)ref_root.size() - 1 - num_merged_domains;
+    if (!reference_tv->domain()->contiguity()[pos]) {
+      break;
+    }
+    num_merged_domains++;
+  }
 
   // No expansion with no merged domain
   if (num_merged_domains == 0) {
@@ -242,14 +238,16 @@ size_t expandVectorizationToContigMergedDomains(
     const auto& tv_root = tv->getMaybeRFactorDomain();
 
     int tv_num_merged_domains = 0;
-    for (const auto i : c10::irange(num_merged_domains)) {
+    for (const auto i : c10::irange(max_num_merged_domains)) {
       if (i == tv_root.size()) {
         break;
       }
       auto ref_id = ref_root.at(ref_root.size() - 1 - i);
-      IterDomain* tv_id = tv_root.at(tv_root.size() - 1 - i);
+      auto pos = tv_root.size() - 1 - i;
+      IterDomain* tv_id = tv_root.at(pos);
       // If not mapped, stop expanding.
-      if (!ca_map.areMapped(ref_id, tv_id, IdMappingMode::EXACT)) {
+      if (!ca_map.areMapped(ref_id, tv_id, IdMappingMode::EXACT) ||
+          !tv->domain()->contiguity()[pos]) {
         break;
       } else {
         ++tv_num_merged_domains;
diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
index 0252a7785d6764..33ef44711525c1 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
@@ -19138,9 +19138,9 @@ TEST_F(NVFuserTest, FusionChannelsLastParser_CUDA) {
   // 2. use a fuzzy compare (ignore non-significant whitespaces for example)
  const std::string expected_kernel = R"(
 __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2, Tensor<__half, 4> T7) {
-  int64_t i171;
-  i171 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x);
-  if ((i171 < (T0.size[0] * (T0.size[1] * (T0.size[2] * T0.size[3]))))) {
+  int64_t i165;
+  i165 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x);
+  if ((i165 < (T0.size[0] * (T0.size[1] * (T0.size[2] * T0.size[3]))))) {
     __half T9[1];
     T9[0] = 0;
     T9[0]
@@ -19148,7 +19148,7 @@ __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2,
     __half T8[1];
     T8[0] = 0;
     T8[0]
-       = T0[i171];
+       = T0[i165];
     float T3[1];
     T3[0]
        = __half2float(T9[0]);
@@ -19168,7 +19168,7 @@ __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2,
     __half T10[1];
     T10[0]
        = __float2half(T6[0]);
-  T7[i171]
+  T7[i165]
      = T10[0];
   }
 }
@@ -26125,6 +26125,160 @@ TEST_F(NVFuserTest, FusionTrivialInputForwarding_CUDA) {
   testValidate(fusion, cg_outputs2, {t0, t1}, {t0}, __LINE__, __FILE__);
 }
 
+namespace {
+
+size_t getVecSizeForPointwise(FusionExecutorCache& fec) {
+  auto most_recent_params =
+      fec.getMostRecentKernelRuntime()->getMostRecentExecutorLog().params;
+  auto params = dynamic_cast<PointwiseParams*>(most_recent_params.get());
+  if (params->vectorize) {
+    return params->unroll_factor;
+  }
+  return 1;
+}
+
+} // namespace
+
+TEST_F(NVFuserTest, FusionVectorizeStrideContiguity2D_CUDA) {
+  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
+  auto fusion = fusion_ptr.get();
+  FusionGuard fg(fusion);
+
+  TensorView* tv0 =
+      TensorViewBuilder().ndims(2).contiguity({false, true}).build();
+  fusion->addInput(tv0);
+  auto tv1 = set(tv0);
+  fusion->addOutput(tv1);
+
+  FusionExecutorCache fec(std::move(fusion_ptr));
+  fec.profile(true);
+
+  std::vector<std::pair<int, int>> size_and_vec{{17, 1}, {18, 2}, {32, 4}};
+
+  for (auto pair : size_and_vec) {
+    auto size = pair.first;
+    auto vec = pair.second;
+    auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0);
+    at::Tensor t0 = at::randn({1000000, size}, options).narrow(1, 0, 16);
+    auto cg_outputs = fec.runFusionWithInputs({t0});
+
+    TORCH_CHECK(getVecSizeForPointwise(fec) == vec);
+
+    testValidate(fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__);
+  }
+}
+
+TEST_F(NVFuserTest, FusionVectorizeStrideContiguity3D_CUDA) {
+  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
+  auto fusion = fusion_ptr.get();
+  FusionGuard fg(fusion);
+
+  TensorView* tv0 =
+      TensorViewBuilder().ndims(3).contiguity({false, true, true}).build();
+  fusion->addInput(tv0);
+  auto tv1 = set(tv0);
+  fusion->addOutput(tv1);
+
+  FusionExecutorCache fec(std::move(fusion_ptr));
+  fec.profile(true);
+
+  std::vector<std::pair<int, int>> size_and_vec{{17, 1}, {10, 2}, {16, 4}};
+
+  for (auto pair : size_and_vec) {
+    auto size = pair.first;
+    auto vec = pair.second;
+    auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0);
+    at::Tensor t0 = at::randn({1000000, size, 3}, options).narrow(1, 0, 8);
+    auto cg_outputs = fec.runFusionWithInputs({t0});
+
+    TORCH_CHECK(getVecSizeForPointwise(fec) == vec);
+
+    testValidate(fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__);
+  }
+}
+
+TEST_F(NVFuserTest, FusionVectorizeStrideContiguity5D_CUDA) {
+  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
+  auto fusion = fusion_ptr.get();
+  FusionGuard fg(fusion);
+
+  TensorView* tv0 = TensorViewBuilder()
+                        .ndims(5)
+                        .contiguity({false, true, false, true, true})
+                        .build();
+  fusion->addInput(tv0);
+  auto tv1 = set(tv0);
+  fusion->addOutput(tv1);
+
+  FusionExecutorCache fec(std::move(fusion_ptr));
+  fec.profile(true);
+
+  auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0);
+
+  std::vector<std::tuple<int, int, int>> sizes_and_vec{
+      {9, 17, 1}, {9, 10, 2}, {9, 16, 4}};
+
+  for (auto tup : sizes_and_vec) {
+    auto size1 = std::get<0>(tup);
+    auto size2 = std::get<1>(tup);
+    auto vec = std::get<2>(tup);
+    at::Tensor t0 = at::randn({4, size1, 12345, size2, 3}, options)
+                        .narrow(1, 0, 8)
+                        .narrow(3, 0, 4);
+    auto cg_outputs = fec.runFusionWithInputs({t0});
+
+    TORCH_CHECK(getVecSizeForPointwise(fec) == vec);
+
+    testValidate(fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__);
+  }
+}
+
+TEST_F(NVFuserTest, FusionVectorizeStrideContiguitySelfOverlapping_CUDA) {
+  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
+  auto fusion = fusion_ptr.get();
+  FusionGuard fg(fusion);
+
+  TensorView* tv0 = TensorViewBuilder()
+                        .ndims(5)
+                        .contiguity({false, true, false, true, true})
+                        .build();
+  fusion->addInput(tv0);
+  auto tv1 = set(tv0);
+  fusion->addOutput(tv1);
+
+  FusionExecutorCache fec(std::move(fusion_ptr));
+  fec.profile(true);
+
+  auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0);
+
+  std::vector<std::tuple<int, int, int, int>> sizes_strides_and_vec{
+      {4, 4, 4, 4},
+      {4, 4, 2, 2},
+      {4, 2, 4, 2},
+      {2, 4, 4, 2},
+      {4, 4, 1, 1},
+      {4, 1, 4, 1},
+      {1, 4, 4, 1},
+      {2, 2, 2, 2},
+      {2, 2, 1, 1},
+      {2, 1, 2, 1},
+      {1, 2, 2, 1}};
+
+  for (auto tup : sizes_strides_and_vec) {
+    auto size = std::get<0>(tup);
+    auto stride1 = std::get<1>(tup);
+    auto stride2 = std::get<2>(tup);
+    auto vec = std::get<3>(tup);
+    std::vector<int64_t> shape = {4, 4, 12345, size, 3};
+    std::vector<int64_t> stride = {stride1, stride2 * 12345, stride2, 3, 1};
+    at::Tensor t0 = at::empty_strided(shape, stride, options);
+    t0.random_();
+    auto cg_outputs = fec.runFusionWithInputs({t0});
+    TORCH_CHECK(getVecSizeForPointwise(fec) == vec);
+    testValidate(fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__);
+  }
+}
+
 } // namespace jit
 } // namespace torch
 #endif // #if defined(USE_CUDA)
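
Appendix: standalone sketches of the new size calculation (not part of the commit; helper names below are illustrative, not nvfuser API).

The core of the fix is the loop added to SchedulerRuntimeInfo::initialize(): it walks each input tensor from the innermost dimension outward and records, in bytes, every stride at which the layout stops matching a fully contiguous one. A minimal sketch of that scan, assuming plain size/stride vectors in place of TensorArgAbstract:

#include <cstdint>
#include <iostream>
#include <vector>

// Collect strides (in bytes) where the layout deviates from a fully
// contiguous one. Mirrors the loop added to initialize(): size-0/1
// dimensions carry no layout information and are skipped.
std::vector<int64_t> collectDiscontigStrides(
    const std::vector<int64_t>& sizes,
    const std::vector<int64_t>& strides,
    int64_t dtype_size) {
  std::vector<int64_t> discontig;
  int64_t expected_stride = 1;
  for (int64_t dim = (int64_t)sizes.size() - 1; dim >= 0; dim--) {
    if (sizes[dim] <= 1) {
      continue;
    }
    if (strides[dim] != expected_stride) {
      discontig.push_back(strides[dim] * dtype_size);
      expected_stride = strides[dim];
    }
    expected_stride *= sizes[dim];
  }
  return discontig;
}

int main() {
  // A float tensor narrowed from {1000000, 18} to {1000000, 16}, as in the
  // 2D test: the inner dimension is contiguous, but rows are 18 elements
  // apart.
  for (int64_t s : collectDiscontigStrides({1000000, 16}, {18, 1}, 4)) {
    std::cout << s << "\n"; // prints 72 (18 elements * 4 bytes)
  }
  return 0;
}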
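
getAlignmentSize() then takes the minimum alignment over the base pointer and every recorded discontiguous stride, since a vectorized access must divide all of them. A sketch of that combination, assuming computeAlignmentSize() behaves as the largest power of two dividing its argument, capped at the class's 16-byte max vector size:

#include <algorithm>
#include <cstddef>
#include <vector>

// Largest power of two dividing 'value', capped at 16 bytes (assumed cap,
// matching the "max vector size we will consider" limit in the header).
size_t computeAlignment(size_t value) {
  size_t alignment = 1;
  while (alignment < 16 && value % (alignment * 2) == 0) {
    alignment *= 2;
  }
  return alignment;
}

// Usable alignment = min over the base address and every discontiguous
// stride, as in the loop added to getAlignmentSize().
size_t alignmentFor(
    size_t base_address,
    const std::vector<size_t>& discontig_strides_bytes) {
  size_t alignment = computeAlignment(base_address);
  for (size_t stride : discontig_strides_bytes) {
    alignment = std::min(alignment, computeAlignment(stride));
  }
  return alignment;
}

This reproduces the FusionVectorizeStrideContiguity2D expectations: an inner size of 17 gives a 68-byte row stride (alignment 4, so 1 float), 18 gives 72 bytes (alignment 8, so 2 floats), and 32 gives 128 bytes (alignment capped at 16, so 4 floats).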
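
On the scheduler side, expandVectorizationToContigMergedDomains() no longer assumes that every root domain to the right of the break point is mergeable: counting now stops at the first non-contiguous domain, both for the reference tensor and, in the per-tensor loop, for each other tensor. A sketch of the corrected count, with a plain vector standing in for reference_tv->domain()->contiguity():

#include <vector>

// Count innermost root domains mergeable for vectorization: at most
// (size - break_point) of them, stopping at the first domain whose
// contiguity flag is false.
int countMergedDomains(const std::vector<bool>& contiguity, int break_point) {
  const int max_num_merged_domains = (int)contiguity.size() - break_point;
  int num_merged_domains = 0;
  while (num_merged_domains < max_num_merged_domains) {
    int pos = (int)contiguity.size() - 1 - num_merged_domains;
    if (!contiguity[pos]) {
      break;
    }
    num_merged_domains++;
  }
  return num_merged_domains;
}

For the contiguity pattern {false, true, false, true, true} used by the 5D tests, a break point of 1 yields 2 merged domains: the scan stops at the discontiguous third-from-innermost domain instead of merging all four.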
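
The expectations in FusionVectorizeStrideContiguitySelfOverlapping_CUDA follow from the same two steps. Take {size, stride1, stride2, vec} = {4, 4, 2, 2}: the shape is {4, 4, 12345, 4, 3} with element strides {4, 24690, 2, 3, 1}. Scanning from the innermost dimension, strides 1 and 3 match a contiguous layout; dimension 2's stride of 2 elements (8 bytes) and dimension 0's stride of 4 elements (16 bytes) do not. The minimum alignment over 8 and 16 bytes is 8, i.e. a vector width of 2 floats, which is exactly the vec the test checks for.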