Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon No.34】优化 poisson op #45160

Merged
merged 9 commits into from
Aug 24, 2022
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 19 additions & 26 deletions paddle/phi/kernels/gpu/poisson_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,48 +27,41 @@ limitations under the License. */
namespace phi {

template <typename T>
struct PoissonCudaFunctor {
public:
PoissonCudaFunctor(const T* in,
T* out,
unsigned int seed,
unsigned int offset)
: in_(in), out_(out), seed_(seed), offset_(offset) {}

__device__ void operator()(int64_t idx) {
__global__ void get_poisson(
Rayman96 marked this conversation as resolved.
Show resolved Hide resolved
const int N, const T* in, T* out, unsigned int seed, unsigned int offset) {
Rayman96 marked this conversation as resolved.
Show resolved Hide resolved
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < N) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里需要写成循环,防止 numel 过大时无法完成所有的计算;用 CUDA_KERNEL_LOOP_TYPE 包一下就可以。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的 我修改下

//
// curandStatePhilox4_32_10_t state;
Rayman96 marked this conversation as resolved.
Show resolved Hide resolved
// curand_init(seed, idx, offset, &state);
// out[idx] = curand_poisson(&state, in[idx]);
#ifdef __NVCC__
curandStatePhilox4_32_10_t state;
curand_init(seed_, idx, offset_, &state);
out_[idx] = static_cast<T>(curand_poisson(&state, in_[idx]));
curand_init(seed, idx, offset, &state);
out[idx] = static_cast<T>(curand_poisson(&state, in[idx]));
#elif __HIPCC__
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed_, idx, offset_, &state);
out_[idx] = static_cast<T>(hiprand_poisson(&state, in_[idx]));
hiprand_init(seed, idx, offset, &state);
out[idx] = static_cast<T>(hiprand_poisson(&state, in[idx]));
#endif
}

private:
const T* in_;
T* out_;
const unsigned int seed_;
const unsigned int offset_;
};
}

template <typename T, typename Context>
void PoissonKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
const T* x_data = x.data<T>();
T* out_data = ctx.template Alloc<T>(out);
auto size = x.numel();
const int size = x.numel();

int block_size = std::min(256, ctx.GetMaxThreadsPerBlock());
Rayman96 marked this conversation as resolved.
Show resolved Hide resolved
dim3 dim_block(block_size);
dim3 dim_grid((size + block_size - 1) / block_size);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

grid的设置也需要进行最大值的限制,可以参考

paddle::platform::LimitGridDim(ctx, &grid_dim);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

> grid的设置也需要进行最大值的限制,可以参考
>
> paddle::platform::LimitGridDim(ctx, &grid_dim);

好的 我修改下

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改


auto gen_cuda = ctx.GetGenerator();
auto seed_offset = gen_cuda->IncrementOffset(20);
uint64_t seed = seed_offset.first;
uint64_t offset = seed_offset.second;

phi::funcs::ForRange<Context> for_range(ctx, size);

PoissonCudaFunctor<T> functor(x_data, out_data, seed, offset);
for_range(functor);
get_poisson<T><<<dim_grid, dim_block>>>(size, x_data, out_data, seed, offset);
}

} // namespace phi
Expand Down