[ROCm] slow torch.sum optimization by increasing max_values_per_thread in reduce config (#1588) (pytorch#135397)

Fixes pytorch#132964

This change optimizes torch.sum() performance by increasing
max_values_per_thread in setReduceConfig() for the ROCm platform.
Increasing this parameter makes the reduction use fewer thread blocks,
which improves performance.
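
The effect on launch size can be sketched with a hypothetical back-of-the-envelope calculation (this is not the actual launch logic in Reduce.cuh; the block size of 512 threads and the helper below are illustrative assumptions):

```python
import math

def num_blocks(n_elements, threads_per_block, max_values_per_thread):
    # Each block can reduce at most threads_per_block * max_values_per_thread
    # values, so a larger per-thread cap means fewer blocks for the same input.
    values_per_block = threads_per_block * max_values_per_thread
    return math.ceil(n_elements / values_per_block)

n = 2**30                             # elements, matching the benchmark below
print(num_blocks(n, 512, 256))        # old cap: 8192 blocks
print(num_blocks(n, 512, 1024))       # new ROCm cap: 2048 blocks (4x fewer)
```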

Test:
Tested on MI300X and H100. With this change, MI300X throughput improved
from ~1690 GByte/s to 3205 GByte/s for the test case below, slightly
better than H100 (3136 GByte/s).

Other tensor sizes were also tested and show similar improvements.

```python
import torch
from triton.testing import do_bench

x = torch.randn(2**30, device='cuda')   # 2^30 float32 values (4 GiB)

ms = do_bench(lambda: x.sum(dim=-1))    # elapsed time in milliseconds

# Effective memory bandwidth: bytes read by the reduction / elapsed time.
bandwidth_gbyte = x.numel() * x.dtype.itemsize / (10**9)
time_s = ms / 1000
bw_per_second = bandwidth_gbyte / time_s

print(bw_per_second)                    # GByte/s
```
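
As a sanity check on the reported figures, the per-call times implied by the bandwidths above can be recovered directly (assuming float32, i.e. 4 bytes per element, as in the benchmark):

```python
# Relate the reported bandwidths to per-call times for a 2**30-element
# float32 tensor: time = bytes_moved / bandwidth.
bytes_moved = 2**30 * 4 / 1e9           # ~4.295 GB read once by the reduction

for name, gb_per_s in [("MI300X before", 1690),
                       ("MI300X after", 3205),
                       ("H100", 3136)]:
    ms = bytes_moved / gb_per_s * 1000
    print(f"{name}: ~{ms:.2f} ms")
```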

Co-author: @carlobertolli

Pull Request resolved: pytorch#135397
Approved by: https://github.com/eqy, https://github.com/malfet


Co-authored-by: hongxyan <[email protected]>
jerrymannil and hongxiayang authored Sep 12, 2024
1 parent f4c8ad5 commit 4360582
Showing 1 changed file with 4 additions and 0 deletions.

aten/src/ATen/native/cuda/Reduce.cuh:

```diff
@@ -1092,7 +1092,11 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){
   }

   constexpr int min_values_per_thread = 16;
+#ifndef USE_ROCM
   constexpr int max_values_per_thread = 256;
+#else
+  constexpr int max_values_per_thread = 1024;
+#endif

   if (config.values_per_thread() >= block_height * 16 || config.values_per_thread() >= max_values_per_thread) {
     // Divide the input across warps in a thread-block, if that leaves at least
```
