From 4360582f00231e9e1650278e07ce7664a345f8e5 Mon Sep 17 00:00:00 2001
From: Jerry Mannil <65309407+jerrymannil@users.noreply.github.com>
Date: Thu, 12 Sep 2024 12:42:55 -0700
Subject: [PATCH] =?UTF-8?q?[ROCm]=20slow=20torch.sum=20optimization=20by?=
 =?UTF-8?q?=20increasing=20max=5Fvalues=5Fper=5Fthrea=E2=80=A6=20(#1588)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

…d in reduce config (#135397)

Fixes #132964

This change optimizes torch.sum() performance by increasing
max_values_per_thread in setReduceConfig() for the ROCm platform.
Increasing this parameter uses fewer thread blocks and improves
performance.

Test: Tested on MI300X and H100. For the test case below, MI300X
bandwidth improved from ~1690 GByte/s to 3205 GByte/s, slightly better
than H100 (3136 GByte/s). Other tensor sizes were also tested and show
performance improvements.

```python
import torch
from triton.testing import do_bench

x = torch.randn(2**30, device='cuda')
ms = do_bench(lambda: x.sum(dim=-1))
bandwidth_gbyte = x.numel() * x.dtype.itemsize / (10**9)
time_s = ms / 1000
bw_per_second = bandwidth_gbyte / time_s
print(bw_per_second)
```

Co-author: @carlobertolli

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135397
Approved by: https://github.com/eqy, https://github.com/malfet

Fixes #ISSUE_NUMBER

Co-authored-by: hongxyan
---
 aten/src/ATen/native/cuda/Reduce.cuh | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/aten/src/ATen/native/cuda/Reduce.cuh b/aten/src/ATen/native/cuda/Reduce.cuh
index 85bde8b5990ff..9d4a9df2f3e59 100644
--- a/aten/src/ATen/native/cuda/Reduce.cuh
+++ b/aten/src/ATen/native/cuda/Reduce.cuh
@@ -1092,7 +1092,11 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){
   }
   constexpr int min_values_per_thread = 16;
+#ifndef USE_ROCM
   constexpr int max_values_per_thread = 256;
+#else
+  constexpr int max_values_per_thread = 1024;
+#endif
 
   if (config.values_per_thread() >= block_height * 16 ||
       config.values_per_thread() >= max_values_per_thread) {
     // Divide the input across warps in a thread-block, if that leaves at least
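For illustration, here is a rough back-of-the-envelope model of why raising the cap shrinks the grid. This is a simplified sketch, not the actual setReduceConfig() logic: the function name, the fixed block size of 512 threads, and the ceiling-division arithmetic are assumptions for the example only.

```python
def blocks_for_reduction(num_items, threads_per_block=512, max_values_per_thread=256):
    # Simplified model: each block can reduce at most
    # threads_per_block * max_values_per_thread input elements,
    # so the grid needs ceil(num_items / items_per_block) blocks.
    items_per_block = threads_per_block * max_values_per_thread
    return (num_items + items_per_block - 1) // items_per_block

# 2**30 elements, as in the benchmark in the commit message:
n = 2**30
print(blocks_for_reduction(n, max_values_per_thread=256))   # 8192 blocks (old cap)
print(blocks_for_reduction(n, max_values_per_thread=1024))  # 2048 blocks (ROCm cap after this patch)
```

Under this model, quadrupling max_values_per_thread cuts the block count by 4x for large inputs, which reduces the cost of combining per-block partial sums in the final reduction step.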