[ROCm] torch.sum optimization by increasing min_values_per_thread (#1591)

Follow-up to pytorch#135397.
AMD GPUs perform better with fewer thread blocks, so increase min_values_per_thread as well.
This improved [CvT](https://github.com/facebookresearch/FAMBench/tree/main/benchmarks/cvt) benchmark performance on MI300X.

Co-author: @carlobertolli
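
For a rough sense of the effect described above, one could time torch.sum on a large GPU tensor with and without the change. The libtorch sketch below is an illustrative assumption, not part of this commit: the tensor shape, loop counts, and build setup are arbitrary and are not taken from the CvT benchmark; on a ROCm build the kCUDA device maps to the HIP backend.

```cpp
// Illustrative timing sketch (assumed setup: a CUDA or ROCm build of libtorch;
// shapes and iteration counts are arbitrary, not from the CvT benchmark).
#include <torch/torch.h>
#include <chrono>
#include <iostream>

int main() {
  if (!torch::cuda::is_available()) {   // also true on ROCm builds
    std::cerr << "No GPU device available\n";
    return 1;
  }
  auto x = torch::randn({4096, 4096}, torch::device(torch::kCUDA));

  for (int i = 0; i < 10; ++i) { (void)x.sum(); }   // warm-up
  torch::cuda::synchronize();

  auto t0 = std::chrono::steady_clock::now();
  for (int i = 0; i < 100; ++i) { (void)x.sum(); }  // GPU sum goes through Reduce.cuh
  torch::cuda::synchronize();
  auto t1 = std::chrono::steady_clock::now();

  auto us = std::chrono::duration_cast<std::chrono::microseconds>(t1 - t0).count();
  std::cout << "avg torch::sum time: " << us / 100.0 << " us\n";
  return 0;
}
```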
jerrymannil authored Sep 13, 2024
1 parent 4360582 commit c1b6f60
Showing 1 changed file with 6 additions and 4 deletions.
aten/src/ATen/native/cuda/Reduce.cuh (10 changes: 6 additions & 4 deletions)
@@ -1091,14 +1091,16 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){
     config.output_mult[0] = config.split_output(block_width);
   }
 
+#ifdef USE_ROCM
+  // AMD gpus perform better with fewer thread blocks
+  constexpr int min_values_per_thread = 128;
+  constexpr int max_values_per_thread = 1024;
+#else
   constexpr int min_values_per_thread = 16;
-#ifndef USE_ROCM
   constexpr int max_values_per_thread = 256;
-#else
-  constexpr int max_values_per_thread = 1024;
 #endif
 
-  if (config.values_per_thread() >= block_height * 16 || config.values_per_thread() >= max_values_per_thread) {
+  if (config.values_per_thread() >= block_height * min_values_per_thread || config.values_per_thread() >= max_values_per_thread) {
     // Divide the input across warps in a thread-block, if that leaves at least
     // 16 elements to be summed by each thread. This will require inter-warp
     // reduction using shared memory.
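Read on its own, the changed condition says: spread the input over a block's warps only if each thread would still keep at least min_values_per_thread elements, or if the per-thread workload already exceeds max_values_per_thread. Below is a minimal standalone sketch of just that decision, outside PyTorch; split_across_warps, the block_height of 8, and the sample workloads are hypothetical stand-ins for the values ReduceConfig actually computes.

```cpp
// Standalone sketch of the threshold decision in the hunk above (not the
// PyTorch implementation; inputs are made-up example values).
#include <cstdio>

#ifdef USE_ROCM
constexpr int min_values_per_thread = 128;   // AMD GPUs: demand more work per thread
constexpr int max_values_per_thread = 1024;  // before splitting it across warps
#else
constexpr int min_values_per_thread = 16;
constexpr int max_values_per_thread = 256;
#endif

// Mirrors the condition changed in this commit: split the input across the
// warps of a block only if each thread would still keep at least
// min_values_per_thread elements, or the workload already exceeds the maximum.
bool split_across_warps(long values_per_thread, int block_height) {
  return values_per_thread >= static_cast<long>(block_height) * min_values_per_thread ||
         values_per_thread >= max_values_per_thread;
}

int main() {
  const int block_height = 8;                  // warps per block (example value)
  const long samples[] = {64, 512, 4096};      // example per-thread workloads
  for (long vpt : samples) {
    std::printf("values_per_thread=%ld -> split=%d\n",
                vpt, split_across_warps(vpt, block_height));
  }
  return 0;
}
```

Compiled without -DUSE_ROCM the 512-element case splits (512 >= 8 * 16); compiled with -DUSE_ROCM it does not (512 < 8 * 128 and 512 < 1024), which is how the higher ROCm minimum keeps more work per thread and fewer thread blocks.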
