diff --git a/aten/src/ATen/native/cuda/Reduce.cuh b/aten/src/ATen/native/cuda/Reduce.cuh index 85bde8b5990ffa..9d4a9df2f3e590 100644 --- a/aten/src/ATen/native/cuda/Reduce.cuh +++ b/aten/src/ATen/native/cuda/Reduce.cuh @@ -1092,7 +1092,11 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){ } constexpr int min_values_per_thread = 16; +#ifndef USE_ROCM constexpr int max_values_per_thread = 256; +#else + constexpr int max_values_per_thread = 1024; +#endif if (config.values_per_thread() >= block_height * 16 || config.values_per_thread() >= max_values_per_thread) { // Divide the input across warps in a thread-block, if that leaves at least