Optimize flip kernel by eliminating H2D data transfer, test=develop #46046
Conversation
Force-pushed from 8f4a035 to 808b3b3
std::vector<int> flip_dims = axis;

template <typename T, typename Context, size_t N>
void launch_flip_cuda_kernel(const Context& dev_ctx,
launch_flip_cuda_kernel -> LaunchFlipCudaKernel; function names should use UpperCamelCase.
  int block_size = 512;
  dim3 dim_block(block_size);
- dim3 dim_grid((N + block_size - 1) / block_size);
+ dim3 dim_grid((numel + block_size - 1) / block_size);
For the thread configuration, you can call the phi::backends::gpu::GetGpuLaunchConfig1D function.
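A minimal sketch of what that suggestion could look like (assumptions: the Context is phi::GPUContext, and the GpuLaunchConfig members are thread_per_block / block_per_grid per my reading of phi's gpu_launch_config.h; the remaining kernel arguments are elided):

// Sketch only: let phi choose the launch configuration instead of
// hard-coding block_size = 512.
#include "paddle/phi/backends/gpu/gpu_launch_config.h"

template <typename T, typename Context, size_t Rank>
void LaunchFlipCudaKernelSketch(const Context& dev_ctx, int64_t numel) {
  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel);
  // flip_cuda_kernel<T, Rank><<<config.block_per_grid,
  //                             config.thread_per_block,
  //                             0,
  //                             dev_ctx.stream()>>>(numel, /* ... */);
}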
  namespace phi {

- template <typename T>
+ template <typename T, size_t Rank>
  __global__ void flip_cuda_kernel(const int N,
Please change N to int64_t in a follow-up PR.
  auto x_dims = x.dims();
  const int total_dims = x_dims.size();
- const int N = x.numel();
+ const int numel = x.numel();
It is recommended to declare numel as int64_t.
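A small self-contained sketch of the suggested type change (the kernel body is a placeholder; only the int64_t plumbing is the point, and x.numel() returning int64_t is my assumption about the DenseTensor API):

#include <cstdint>

// Sketch only: carry the element count as int64_t from host to device so
// tensors with more than 2^31 - 1 elements still index correctly.
template <typename T>
__global__ void FlipCudaKernelSketch(const int64_t numel, const T* in, T* out) {
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx >= numel) return;
  out[idx] = in[idx];  // placeholder body; the real kernel permutes indices
}

// Host side:
// const int64_t numel = x.numel();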
for (size_t idx = 0; idx < N; ++idx) {
  stride_a[idx] = x_stride[idx];
  shape_a[idx] = x_dims[idx];
  flip_dims_a[idx] = idx < flip_dims_size ? flip_dims_v[idx] : 0;
flip_dims_v also seems unnecessary; couldn't you just write into flip_dims_a directly?
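Roughly what writing into the array directly could look like (kMaxRank and the plain C array are hypothetical stand-ins for whatever fixed-size container the kernel actually takes; axis and total_dims are the names from the diff):

#include <cstdint>
#include <vector>

constexpr size_t kMaxRank = 9;  // hypothetical upper bound on tensor rank

// Sketch only: normalize the flip axes straight into the fixed-size array
// handed to the kernel, skipping the intermediate flip_dims_v vector.
void FillFlipDims(const std::vector<int>& axis,
                  int total_dims,
                  int64_t (&flip_dims_a)[kMaxRank]) {
  for (size_t i = 0; i < kMaxRank; ++i) flip_dims_a[i] = 0;  // zero-pad unused slots
  for (size_t i = 0; i < axis.size() && i < kMaxRank; ++i) {
    // Wrap negative axes once, e.g. -1 becomes total_dims - 1.
    flip_dims_a[i] = axis[i] < 0 ? axis[i] + total_dims : axis[i];
  }
}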
LGTM. The review suggestions can be addressed in the next PR.
PR types
Performance optimization
PR changes
OPs
Describe
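A simplified illustration of the idea in the title, not the merged phi code: the shape/stride/flip-axis metadata travels as small by-value kernel arguments instead of being copied into a device buffer (a host-to-device transfer) before every launch. All names, the DimArray struct, and the fixed rank below are hypothetical.

#include <cstdint>

constexpr size_t kRank = 2;  // rank fixed at compile time, mirroring the templated kernel

struct DimArray {  // stand-in for a fixed-size array type passed by value
  int64_t data[kRank];
};

template <typename T>
__global__ void FlipKernelSketch(int64_t numel,
                                 DimArray shape,
                                 DimArray stride,
                                 DimArray flip_dims,
                                 int flip_dims_size,
                                 const T* in,
                                 T* out) {
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx >= numel) return;

  // Decompose the linear index into coordinates, flip the requested axes,
  // and recompose the destination index.
  int64_t rem = idx, dst = 0;
  for (size_t d = 0; d < kRank; ++d) {
    int64_t coord = rem / stride.data[d];
    rem -= coord * stride.data[d];
    for (int f = 0; f < flip_dims_size; ++f) {
      if (flip_dims.data[f] == static_cast<int64_t>(d)) {
        coord = shape.data[d] - 1 - coord;
        break;
      }
    }
    dst += coord * stride.data[d];
  }
  out[dst] = in[idx];
}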
Test environment
pytorch: 1.12.1+102
paddle: 2.3+102
cuda: 11.2
iterations: 1000
dtype: float32
shape: [100, 1785]
axis: 1
Test Result