[ROCm] slow torch.sum optimization by increasing max_values_per_thread in reduce config (pytorch#135397) (#1588)

Fixes pytorch#132964

This change optimizes torch.sum() performance by increasing max_values_per_thread in setReduceConfig() for the ROCm platform. Increasing this parameter uses fewer thread blocks and improves performance.

Test: Tested on MI300x and H100. For the test case below, MI300x bandwidth improved from ~1690 GByte/s to 3205 GByte/s and is now slightly better than H100 (3136 GByte/s). Other tensor sizes were also tested and likewise show a performance improvement.

```python
import torch
from triton.testing import do_bench

x = torch.randn(2**30, device='cuda')
ms = do_bench(lambda: x.sum(dim=-1))
bandwidth_gbyte = x.numel() * x.dtype.itemsize / (10**9)
time_s = ms / 1000
bw_per_second = bandwidth_gbyte / time_s
print(bw_per_second)
```

Co-author: @carlobertolli

Pull Request resolved: pytorch#135397
Approved by: https://github.com/eqy, https://github.com/malfet

Fixes #ISSUE_NUMBER

Co-authored-by: hongxyan <[email protected]>
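As a rough illustration of the reasoning in the message (each thread covering more values means fewer thread blocks), here is a minimal standalone Python sketch of the grid-sizing arithmetic. The `num_blocks` helper, the block size of 256, and the two caps (256 and 1024) are illustrative assumptions, not the actual code or constants in setReduceConfig().

```python
# Standalone sketch (not PyTorch source): why a larger cap on values-per-thread
# shrinks the number of thread blocks a reduction launches. The block size and
# the two caps below are illustrative assumptions, not ATen's actual constants.

def num_blocks(n, block_threads, values_per_thread):
    """Thread blocks needed when each thread accumulates up to
    `values_per_thread` elements before the block-level reduction."""
    elems_per_block = block_threads * values_per_thread
    return -(-n // elems_per_block)  # ceiling division

n = 2**30            # same element count as the benchmark above
block_threads = 256  # assumed block size

# A smaller cap launches more blocks, whose partial results must still be
# combined; a larger cap launches fewer blocks, which is what helped on MI300x.
for cap in (256, 1024):
    print(f"max_values_per_thread={cap}: {num_blocks(n, block_threads, cap)} blocks")
```

Fewer blocks mean less inter-block combining of partial results, which is consistent with the bandwidth gain reported for MI300x above.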