
Optimize the implementation of the argsort operator. #47738

Merged (19 commits) on Nov 29, 2022

Conversation

@Vvsmile (Contributor) commented Nov 7, 2022

PR types

Performance optimization

PR changes

OPs

Describe

Optimize the implementation of the argsort operator so that its performance exceeds PyTorch's.
For an input of size [4, 68892], which was initially slower than PyTorch, the optimized argsort kernel achieves a 3.3x speedup over the previous implementation, roughly 1.5x the performance of PyTorch.
The code of the argsort operator is also reorganized to reduce redundancy.
The optimized code path is taken when the input tensor has a small number of rows; when the number of rows is large, the original implementation is called (whose performance also beats PyTorch). See the dispatch sketch after the table below.

| case | pytorch | paddle (before optimization) | diff | paddle (after optimization) | diff | speedup |
| --- | --- | --- | --- | --- | --- | --- |
| x(Variable) dtype: float32, shape: [1L, 68892L], axis(int): -1 | 0.17 | 0.23 | lower than (35.3%) | 0.20 | lower than (17.6%) | 1.15 |
| x(Variable) dtype: float32, shape: [2L, 68892L], axis(int): -1 | 0.25 | 0.91 | lower than (2.64) | 0.28 | lower than (12%) | 3.25 |
| x(Variable) dtype: float32, shape: [4L, 68892L], axis(int): -1 | 0.42 | 0.97 | lower than (1.31) | 0.29 | greater than (31.0%) | 3.34 |
| x(Variable) dtype: float32, shape: [8L, 68892L], axis(int): -1 | 0.71 | 1.03 | lower than (45.1%) | 0.39 | greater than (45.1%) | 2.64 |
| x(Variable) dtype: float32, shape: [16L, 68892L], axis(int): -1 | 1.21 | 1.07 | greater than (11.6%) | 1.04 | greater than (14.0%) | 1.03 |
| x(Variable) dtype: float32, shape: [32L, 68892L], axis(int): -1 | 2.23 | 1.31 | greater than (41.3%) | 1.11 | greater than (50.2%) | 1.18 |
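Expressed as code, the two-path dispatch described above might look like the following sketch. All names here are illustrative stand-ins, not the PR's actual identifiers; the row threshold of 32 comes from a code comment later in the diff.

```cuda
// Hypothetical dispatch between the two argsort paths described above;
// SortSmallRowTensor / SortLargeRowTensor stand in for the PR's real kernels.
template <typename T, typename IndexType>
void ArgsortDispatchSketch(const phi::GPUContext &ctx,
                           const T *in_data, T *out_data, IndexType *indices,
                           int64_t num_rows, int64_t num_cols,
                           bool descending) {
  if (num_rows <= 32) {
    // Optimized path: all rows sorted in one segmented radix sort.
    SortSmallRowTensor(ctx, in_data, out_data, indices,
                       num_rows, num_cols, descending);
  } else {
    // Original path, which already outperforms PyTorch at larger row counts.
    SortLargeRowTensor(ctx, in_data, out_data, indices,
                       num_rows, num_cols, descending);
  }
}
```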

Reorganize the code to reduce code redundancy.

The optimized code path is used when the input tensor has a small
number of rows.

modification:
	paddle/phi/kernels/gpu/argsort_kernel.cu

test=develop
@paddle-bot bot commented Nov 7, 2022

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI first. See the Paddle CI Manual for details.
@paddle-bot paddle-bot bot added contributor External developers status: proposed labels Nov 7, 2022
@luotao1 luotao1 removed the contributor External developers label Nov 8, 2022
@@ -76,13 +76,285 @@ static __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
}
}

// this is defined for CUDA LOOP
#define CUDA_KERNEL_LOOP_TYPE(i, n, index_type) \
@ZzSean (Contributor) commented Nov 8, 2022:

This macro is already defined in Paddle, in Paddle/paddle/phi/backends/gpu/cuda/cuda_helper.h.

@Vvsmile (Contributor, Author) replied:

OK, I'll remove this.

__index__ += blockDim.x * gridDim.x, i = __index__)
#define CUDA_KERNEL_LOOP(i, n) CUDA_KERNEL_LOOP_TYPE(i, n, int)
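For reference, the helper already in paddle/phi/backends/gpu/cuda/cuda_helper.h follows the standard grid-stride-loop pattern; a sketch of that shape (not a verbatim copy of Paddle's macro):

```cuda
// Grid-stride loop: each thread starts at its global index and strides by the
// total number of launched threads until all n elements are covered.
#define CUDA_KERNEL_LOOP_TYPE(i, n, index_type)                    \
  int64_t __index__ =                                              \
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x; \
  for (index_type i = __index__; __index__ < (n);                  \
       __index__ += blockDim.x * gridDim.x, i = __index__)
```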

static __global__ void Fill_index_and_segment_kernel(int2 *data,
Contributor commented:

  1. Why add static here?
  2. Function names should use UpperCamelCase.

@Vvsmile (Contributor, Author) replied:

I'll switch it to UpperCamelCase.
Nothing outside calls it, so I thought I'd define it directly inside this module. The static can also be removed; there is no function with the same name in the global scope.


constexpr int CUDA_NUM_THREADS = 1024;

inline int GET_BLOCKS(const int64_t N,
Contributor commented:

Same as above.

@Vvsmile (Contributor, Author) replied:

OK.


inline int GET_BLOCKS(const int64_t N,
const int64_t max_threads_per_block = CUDA_NUM_THREADS) {
constexpr int64_t max_int = std::numeric_limits<int>::max();
Contributor commented:

This value does not appear to be used.

@Vvsmile (Contributor, Author) replied:

Yes, that was redundant; it has been deleted.
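A cleaned-up version of the helper without the unused constant might look like the following sketch (the CamelCase name GetBlocks is an assumption, following the naming comments above):

```cuda
#include <cstdint>

constexpr int CUDA_NUM_THREADS = 1024;  // default threads per block, as above

// Number of blocks needed to cover n elements, rounding up.
inline int GetBlocks(const int64_t n,
                     const int64_t max_threads_per_block = CUDA_NUM_THREADS) {
  return static_cast<int>((n + max_threads_per_block - 1) /
                          max_threads_per_block);
}
```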


TransposeKernel<int64_t, Context>(dev_ctx, tmp_indices, trans, indices);
// transpose back
Contributor commented:

These comments should not be deleted, right?

@Vvsmile (Contributor, Author) replied:

I'll put this comment back. I added notes while reading through this code and later deleted too many of them.


// Use thrust for parallel acceleration when the input size is equal to the
Contributor commented:

Why were these comments deleted?

@Vvsmile (Contributor, Author) replied:

Same as above.

DenseTensor* output,
DenseTensor* indices) {
auto in_dims = input.dims();
axis = (axis < 0) ? (in_dims.size() + axis) : axis;
Contributor commented:

This line cannot be deleted either, right?

@Vvsmile (Contributor, Author) replied:

Right, it indeed cannot be.

self_ptr, values_ptr, indices_ptr, i_s_ptr, nsegments, nsort);
}

// Sort by flag descending, True: descending. False: Ascending.
Contributor commented:

The comment formatting is a bit messy.

@Vvsmile (Contributor, Author) replied:

I'll go through all of them.

const IndType num_rows,
const IndType num_cols,
const bool descending) {
auto cu_stream = ctx.stream();
Contributor commented:

cu_stream -> gpu_stream

@Vvsmile (Contributor, Author) replied:

OK.

ctx.template Alloc<IndType>(&input_indices);
size_t temp_storage_bytes = -1;

int block_size = 1024;
Contributor commented:

Don't use magic numbers; use the CUDA_NUM_THREADS you defined earlier.

@Vvsmile (Contributor, Author) replied:

I've changed this; it just hasn't been pushed yet.


int block_size = 1024;
int maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0];
int grid_size = num_rows < maxGridDimX ? num_rows : maxGridDimX;
@ZzSean (Contributor) commented Nov 8, 2022:

The grid and block setup here does not seem related to the GET_BLOCKS defined at the top. Try to extract identical or similar logic into a shared helper to avoid duplicated code.
Also, it looks like the block and grid computed here are never actually used.

@Vvsmile (Contributor, Author) replied:

I've already deleted this; it just hasn't been pushed yet.
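One way to realize the shared-logic suggestion, building on the GetBlocks sketch above (illustrative, not the PR's final code):

```cuda
#include <algorithm>
#include <cstdint>

// Derive the launch grid from the shared helper and clamp it to the device's
// maximum grid dimension; max_grid_dim_x would come from
// ctx.GetCUDAMaxGridDimSize()[0] in the real kernel code.
inline int ClampedGridSize(int64_t n, int64_t max_grid_dim_x) {
  return static_cast<int>(std::min<int64_t>(GetBlocks(n), max_grid_dim_x));
}
```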

@paddle-bot-old paddle-bot-old bot added the contributor External developers label Nov 8, 2022
This version passes all of the unittests.

Merge action:
	Merge branch 'develop' of https://github.com/Vvsmile/Paddle into develop
This version passes all of the unittests.

Merge action:
	Merge branch 'develop' of https://github.com/Vvsmile/Paddle into develop
Modify the argsort operator to pass the speed testing.
@@ -64,8 +64,12 @@ struct SegmentOffsetIter {
int num_cols_;
};

// using phi::paddle::platform::PADDLE_CUDA_NUM_THREADS;
Contributor commented:

This can be deleted.

@Vvsmile (Contributor, Author) replied:

Sure.

PADDLE_ENFORCE_GPU_SUCCESS(err); \
} while (false)

inline int GET_BLOCKS(
Contributor commented:

Function names should use UpperCamelCase. Also, if the function is this simple, maybe it does not need to be wrapped at all.

@Vvsmile (Contributor, Author) replied:

Understood, I'll take care of it.

auto in_dims = input.dims();
axis = (axis < 0) ? (in_dims.size() + axis) : axis;
const T* in_data = input.data<T>();
Contributor commented:

This part probably does not need to change.

@Vvsmile (Contributor, Author) replied:

Sure.

@Vvsmile (Contributor, Author) replied:

I checked this: it was not deleted. I did accidentally delete it while editing earlier, but I added it back; it is just two lines below this spot, which does not affect execution.

Delete some unused code and change the function names to CamelCase.
}
}

template <typename scalar_t>
Contributor commented:

Use T for the template parameter; same below.

segment_bits,
ctx);

SortPostprocessKernel<<<(n + 511) / 512, 512, 0, cu_stream>>>(
Contributor commented:

This block and grid setup is a bit brute-force... Use variables instead, or can't the result computed above be reused directly?

@Vvsmile (Contributor, Author) replied:

Changed to reuse the result computed above.
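Presumably the fix computes the launch bounds from named values rather than literals; a sketch of that shape, with the kernel's argument list elided:

```cuda
// Sketch: launch SortPostprocessKernel with bounds derived from
// CUDA_NUM_THREADS instead of the hard-coded (n + 511) / 512 and 512.
// `n` and `cu_stream` come from the surrounding code.
const int block = CUDA_NUM_THREADS;
const int grid = static_cast<int>((n + block - 1) / block);
SortPostprocessKernel<<<grid, block, 0, cu_stream>>>(/* ...args... */);
```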


// The method is called when # of the rows of the input is less than or equal to
// 32
template <typename T, typename IndType>
Contributor commented:

IndType -> IndexType

@Vvsmile (Contributor, Author) replied:

Fixed.

@ZzSean (Contributor) commented Nov 25, 2022:

Please add the optimization results to the PR description.

template <typename KT, typename VT>
static void RadixSortPairsImpl(const KT *keys_in,
KT *keys_out,
const VT *values_in,
Contributor commented:

Group the inputs and outputs together in the parameter order: inputs first, then outputs.

@Vvsmile (Contributor, Author) replied:

OK.

bool descending,
int64_t begin_bit,
int64_t end_bit,
const phi::GPUContext &ctx) {
Contributor commented:

Put the ctx parameter first.

const phi::GPUContext &ctx) {
if (keys_out == nullptr) {
DenseTensor key_out_owner;
int64_t key_out_owner_size = n;
Contributor commented:

Can't you just use n directly?
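For context, a wrapper like RadixSortPairsImpl typically follows CUB's two-phase pattern: a first call with a null buffer to size the scratch space, then the real sort. A minimal sketch under that assumption (the wrapper name, parameter order, and allocation style are illustrative, pieced together from the snippets above):

```cuda
#include <cub/cub.cuh>

#include "paddle/phi/backends/gpu/gpu_context.h"  // phi::GPUContext
#include "paddle/phi/core/ddim.h"                 // phi::make_ddim
#include "paddle/phi/core/dense_tensor.h"         // phi::DenseTensor

// Two-phase CUB radix sort of (key, value) pairs; a sketch, not the PR's
// exact implementation.
template <typename KT, typename VT>
void RadixSortPairsSketch(const phi::GPUContext &ctx,
                          const KT *keys_in, const VT *values_in,
                          KT *keys_out, VT *values_out,
                          int64_t n, bool descending,
                          int begin_bit = 0,
                          int end_bit = static_cast<int>(sizeof(KT) * 8)) {
  auto run = [&](void *temp_storage, size_t &temp_bytes) {
    if (descending) {
      cub::DeviceRadixSort::SortPairsDescending(
          temp_storage, temp_bytes, keys_in, keys_out, values_in, values_out,
          static_cast<int>(n), begin_bit, end_bit, ctx.stream());
    } else {
      cub::DeviceRadixSort::SortPairs(
          temp_storage, temp_bytes, keys_in, keys_out, values_in, values_out,
          static_cast<int>(n), begin_bit, end_bit, ctx.stream());
    }
  };
  // Phase 1: a null buffer makes CUB report the scratch size it needs.
  size_t temp_bytes = 0;
  run(nullptr, temp_bytes);
  // Phase 2: allocate the scratch space (via a DenseTensor, as the PR appears
  // to do) and perform the actual sort.
  phi::DenseTensor temp;
  temp.Resize(phi::make_ddim({static_cast<int64_t>(temp_bytes)}));
  void *temp_ptr = ctx.Alloc<uint8_t>(&temp);
  run(temp_ptr, temp_bytes);
}
```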

template <typename KT, typename VT>
static void RadixSortPairs(const phi::GPUContext &ctx,
const KT *keys_in,
KT *keys_out,
Contributor commented:

Same as above.

}

template <typename KT, typename VT>
static void RadixSortPairs(const phi::GPUContext &ctx,
Contributor commented:

Is there any difference between this function and RadixSortPairsImpl? It does not seem to do any extra processing.


template <typename KT>
static void RadixSortKeys(const KT *keys_in,
KT *keys_out,
Contributor commented:

Same as above.


T *sorted_out_ptr;
IndexType *sorted_indices_ptr;
const T *inp = input->data<T>();
Contributor commented:

Rename this as well: inp -> input_data.
Same below.

const int64_t nsort,
const int64_t n,
const bool descending,
const T *const self_ptr,
@ZzSean (Contributor) commented Nov 25, 2022:

Adjust the parameter order; the usual convention is: context, inputs, outputs, then other parameters.
Fix the other functions that have the same issue as well.

@@ -1,3 +1,4 @@

Contributor commented:

Delete the blank line.

const IndexType num_cols,
const bool descending) {
auto gpu_stream = ctx.stream();
DenseTensor input_indices;
@ZzSean (Contributor) commented Nov 25, 2022:

This tensor is not used anywhere in your function.

@ZzSean (Contributor) left a review comment:

LGTM

@ZzSean ZzSean merged commit 9e9b705 into PaddlePaddle:develop Nov 29, 2022
lxsbupt pushed a commit to lxsbupt/Paddle that referenced this pull request Dec 17, 2022

Optimize the implementation of the argsort operator
Vvsmile added a commit to Vvsmile/Paddle that referenced this pull request May 9, 2023
Xreki pushed a commit that referenced this pull request May 10, 2023
Labels
contributor External developers
Projects
None yet
3 participants