From 674f58a696360cc9d9ab367a3b446936fd091c06 Mon Sep 17 00:00:00 2001
From: chunhuanMeng
Date: Mon, 24 Jun 2024 02:45:18 +0000
Subject: [PATCH 01/16] add AveragePool2d op

---
 src/aten/AveragePool2d.cpp             | 311 +++++++++
 src/aten/XPUFallback.template          |   2 -
 src/aten/sycl/AveragePool2dKernels.cpp | 845 +++++++++++++++++++++++++
 src/aten/sycl/AveragePool2dKernels.h   |  29 +
 yaml/xpu_functions.yaml                |   4 +
 5 files changed, 1189 insertions(+), 2 deletions(-)
 create mode 100644 src/aten/AveragePool2d.cpp
 create mode 100644 src/aten/sycl/AveragePool2dKernels.cpp
 create mode 100644 src/aten/sycl/AveragePool2dKernels.h

diff --git a/src/aten/AveragePool2d.cpp b/src/aten/AveragePool2d.cpp
new file mode 100644
index 000000000..55f545b46
--- /dev/null
+++ b/src/aten/AveragePool2d.cpp
@@ -0,0 +1,311 @@
+#include
+#include
+#include
+#include
+#include
+
+namespace at {
+using namespace at::native;
+using namespace at::native::xpu;
+
+Tensor& avg_pool2d_meta(
+    const Tensor& input,
+    Tensor& output,
+    IntArrayRef kernel_size,
+    IntArrayRef stride,
+    IntArrayRef padding,
+    bool ceil_mode,
+    bool count_include_pad,
+    std::optional<int64_t> divisor_override) {
+  TORCH_CHECK(
+      kernel_size.size() == 1 || kernel_size.size() == 2,
+      "avg_pool2d: kernel_size must either be a single int, or a tuple "
+      "of two ints");
+  const int64_t kH = kernel_size[0];
+  const int64_t kW = kernel_size.size() == 1 ? kH : kernel_size[1];
+
+  TORCH_CHECK(
+      stride.empty() || stride.size() == 1 || stride.size() == 2,
+      "avg_pool2d: stride must either be omitted, a single int, or a "
+      "tuple of two ints");
+  const int64_t dH = stride.empty() ? kH : stride[0];
+  const int64_t dW = stride.empty() ? kW : stride.size() == 1 ? dH : stride[1];
+
+  TORCH_CHECK(
+      padding.size() == 1 || padding.size() == 2,
+      "avg_pool2d: padding must either be a single int, or a tuple of "
+      "two ints");
+  const int64_t padH = padding[0];
+  const int64_t padW = padding.size() == 1 ? padH : padding[1];
+
+  TORCH_CHECK(
+      !divisor_override.has_value() || divisor_override.value() != 0,
+      "divisor must be not zero");
+
+  const int64_t nbatch = input.ndimension() == 4 ? input.size(-4) : 1;
+  const int64_t nInputPlane = input.size(-3);
+  const int64_t inputHeight = input.size(-2);
+  const int64_t inputWidth = input.size(-1);
+
+  const int64_t outputHeight =
+      pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, 1, ceil_mode);
+  const int64_t outputWidth =
+      pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, 1, ceil_mode);
+
+  auto memory_format = input.suggest_memory_format();
+  pool2d_shape_check(
+      input,
+      kH,
+      kW,
+      dH,
+      dW,
+      padH,
+      padW,
+      1,
+      1,
+      nInputPlane,
+      inputHeight,
+      inputWidth,
+      outputHeight,
+      outputWidth,
+      memory_format);
+  /* resize output */
+  if (input.ndimension() == 3) {
+    if (output.defined()) {
+      at::xpu::resize_out(
+          output,
+          {nInputPlane, outputHeight, outputWidth},
+          {},
+          input.options());
+    } else {
+      output = at::xpu::create_out(
+          {nInputPlane, outputHeight, outputWidth}, {}, input.options());
+    }
+  } else {
+    if (output.defined()) {
+      at::xpu::resize_out(
+          output,
+          {nbatch, nInputPlane, outputHeight, outputWidth},
+          {},
+          input.options().memory_format(memory_format));
+    } else {
+      output = at::xpu::create_out(
+          {nbatch, nInputPlane, outputHeight, outputWidth},
+          {},
+          input.options().memory_format(memory_format));
+    }
+  }
+
+  return output;
+}
+
+Tensor& avg_pool2d_backward_meta(
+    const Tensor& gradOutput_,
+    Tensor& grad_input,
+    const Tensor& input,
+    IntArrayRef kernel_size,
+    IntArrayRef stride,
+    IntArrayRef padding,
+    bool ceil_mode,
+    bool count_include_pad,
+    std::optional<int64_t> divisor_override) {
+  TORCH_CHECK(
+      kernel_size.size() == 1 || kernel_size.size() == 2,
+      "avg_pool2d: kernel_size must either be a single int, or a tuple "
+      "of two ints");
+  const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
+  const int kW = kernel_size.size() == 1
+      ? kH
+      : safe_downcast<int, int64_t>(kernel_size[1]);
+
+  TORCH_CHECK(
+      stride.empty() || stride.size() == 1 || stride.size() == 2,
+      "avg_pool2d: stride must either be omitted, a single int, or a "
+      "tuple of two ints");
+  const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
+  const int dW = stride.empty() ? kW
+      : stride.size() == 1      ? dH
+                                : safe_downcast<int, int64_t>(stride[1]);
+
+  TORCH_CHECK(
+      padding.size() == 1 || padding.size() == 2,
+      "avg_pool2d: padding must either be a single int, or a tuple of "
+      "two ints");
+  const int padH = safe_downcast<int, int64_t>(padding[0]);
+  const int padW =
+      padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);
+
+  TORCH_CHECK(
+      !divisor_override.has_value() || divisor_override.value() != 0,
+      "divisor must be not zero");
+
+  /* sizes */
+  const int64_t nbatch = input.ndimension() == 4 ? input.size(-4) : 1;
+  const int64_t nInputPlane = input.size(-3); // number of channels (or colors)
+  const int64_t inputHeight = input.size(-2);
+  const int64_t inputWidth = input.size(-1);
+  const int64_t outputWidth =
+      pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, 1, ceil_mode);
+  const int64_t outputHeight =
+      pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, 1, ceil_mode);
+
+  auto memory_format = input.suggest_memory_format();
+  avg_pool2d_backward_shape_check(
+      input,
+      gradOutput_,
+      nbatch,
+      kH,
+      kW,
+      dH,
+      dW,
+      padH,
+      padW,
+      nInputPlane,
+      inputHeight,
+      inputWidth,
+      outputHeight,
+      outputWidth,
+      memory_format);
+
+  if (grad_input.defined()) {
+    at::xpu::resize_out(
+        grad_input,
+        input.sizes(),
+        {},
+        input.options().memory_format(memory_format));
+  } else {
+    grad_input = at::xpu::create_out(
+        input.sizes(), {}, input.options().memory_format(memory_format));
+  }
+  return grad_input;
+}
+
+Tensor XPUNativeFunctions::avg_pool2d(
+    const at::Tensor& input,
+    at::IntArrayRef kernel_size,
+    at::IntArrayRef stride,
+    at::IntArrayRef padding,
+    bool ceil_mode,
+    bool count_include_pad,
+    c10::optional<int64_t> divisor_override) {
+  Tensor output;
+  output = avg_pool2d_meta(
+      input,
+      output,
+      kernel_size,
+      stride,
+      padding,
+      ceil_mode,
+      count_include_pad,
+      divisor_override);
+
+  at::native::xpu::avg_pool2d_kernel(
+      input,
+      kernel_size,
+      stride,
+      padding,
+      ceil_mode,
+      count_include_pad,
+      divisor_override,
+      output);
+  return output;
+}
+
+Tensor& XPUNativeFunctions::avg_pool2d_out(
+    const at::Tensor& input,
+    at::IntArrayRef kernel_size,
+    at::IntArrayRef stride,
+    at::IntArrayRef padding,
+    bool ceil_mode,
+    bool count_include_pad,
+    c10::optional<int64_t> divisor_override,
+    Tensor& output) {
+  avg_pool2d_meta(
+      input,
+      output,
+      kernel_size,
+      stride,
+      padding,
+      ceil_mode,
+      count_include_pad,
+      divisor_override);
+
+  at::native::xpu::avg_pool2d_kernel(
+      input,
+      kernel_size,
+      stride,
+      padding,
+      ceil_mode,
+      count_include_pad,
+      divisor_override,
+      output);
+  return output;
+}
+
+Tensor XPUNativeFunctions::avg_pool2d_backward(
+    const at::Tensor& grad_output,
+    const at::Tensor& input,
+    at::IntArrayRef kernel_size,
+    at::IntArrayRef stride,
+    at::IntArrayRef padding,
+    bool ceil_mode,
+    bool count_include_pad,
+    c10::optional<int64_t> divisor_override) {
+  Tensor grad_input;
+  grad_input = avg_pool2d_backward_meta(
+      grad_output,
+      grad_input,
+      input,
+      kernel_size,
+      stride,
+      padding,
+      ceil_mode,
+      count_include_pad,
+      divisor_override);
+  at::native::xpu::avg_pool2d_backward_kernel(
+      grad_output,
+      input,
+      kernel_size,
+      stride,
+      padding,
+      ceil_mode,
+      count_include_pad,
+      divisor_override,
+      grad_input);
+  return grad_input;
+}
+
+Tensor& XPUNativeFunctions::avg_pool2d_backward_out(
+    const at::Tensor& grad_output,
+    const at::Tensor& input,
+    at::IntArrayRef kernel_size,
+    at::IntArrayRef stride,
+    at::IntArrayRef padding,
+    bool ceil_mode,
+    bool count_include_pad,
+    c10::optional<int64_t> divisor_override,
+    Tensor& grad_input) {
+  avg_pool2d_backward_meta(
+      grad_output,
+      grad_input,
+      input,
+      kernel_size,
+      stride,
+      padding,
+      ceil_mode,
+      count_include_pad,
+      divisor_override);
+  at::native::xpu::avg_pool2d_backward_kernel(
+      grad_output,
+      input,
+      kernel_size,
+      stride,
+      padding,
+      ceil_mode,
+      count_include_pad,
+      divisor_override,
+      grad_input);
+  return grad_input;
+}
+
+} // namespace at
\ No newline at end of file
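For reference, the outputHeight/outputWidth values computed by the meta
functions above follow ATen's standard pooling arithmetic: floor division,
plus one extra window in ceil mode provided that window starts inside the
padded input. A minimal standalone sketch of that rule (the helper name is
illustrative, not part of this patch):

    #include <cstdint>

    // Sketch of ATen's pooling_output_shape with dilation == 1.
    int64_t pooled_size(
        int64_t in, int64_t k, int64_t pad, int64_t stride, bool ceil_mode) {
      int64_t out =
          (in + 2 * pad - k + (ceil_mode ? stride - 1 : 0)) / stride + 1;
      // In ceil mode, the last window must start inside the padded input.
      if (ceil_mode && (out - 1) * stride >= in + pad) {
        --out;
      }
      return out;
    }

    // e.g. pooled_size(6, 3, /*pad=*/0, /*stride=*/2, false) == 2
    //      pooled_size(6, 3, /*pad=*/0, /*stride=*/2, true)  == 3
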
diff --git a/src/aten/XPUFallback.template b/src/aten/XPUFallback.template
index b844ad42a..94631152e 100644
--- a/src/aten/XPUFallback.template
+++ b/src/aten/XPUFallback.template @@ -180,8 +180,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) { "atan2.out", "atanh.out", "atan.out", - "avg_pool2d_backward.grad_input", - "avg_pool2d.out", "avg_pool3d_backward.grad_input", "avg_pool3d.out", "binary_cross_entropy", diff --git a/src/aten/sycl/AveragePool2dKernels.cpp b/src/aten/sycl/AveragePool2dKernels.cpp new file mode 100644 index 000000000..03a705f81 --- /dev/null +++ b/src/aten/sycl/AveragePool2dKernels.cpp @@ -0,0 +1,845 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { +namespace xpu { + +inline int min(int a, int b) { + return a <= b ? a : b; +} + +inline int max(int a, int b) { + return a >= b ? a : b; +} + +template +struct AvgPool2dFrameKernelFunctor { + void operator()(sycl::nd_item<1> item) const { + auto index = item.get_global_linear_id(); + + if (index < total_elements) { + const int pw = index % pooled_width; + const int ph = (index / pooled_width) % pooled_height; + const int c = (index / pooled_width / pooled_height) % channels; + const int n = index / pooled_width / pooled_height / channels; + + int hstart = ph * stride_h - pad_h; + int wstart = pw * stride_w - pad_w; + int hend = min(hstart + kernel_h, height + pad_h); + int wend = min(wstart + kernel_w, width + pad_w); + const int pool_size = (hend - hstart) * (wend - wstart); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + hend = min(hend, height); + wend = min(wend, width); + + if (hstart >= hend || wstart >= wend) { + top_data[index] = scalar_t(0); + return; + } + + accscalar_t aveval = accscalar_t(0); + const scalar_t* const bottom_slice = + bottom_data + (n * channels + c) * height * width; + + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + aveval += bottom_slice[h * width + w]; + } + } + int divide_factor; + if (use_divisor) { + divide_factor = divisor_override; + } else { + if (count_include_pad) { + divide_factor = pool_size; + } else { + divide_factor = (hend - hstart) * (wend - wstart); + } + } + top_data[index] = static_cast(aveval / divide_factor); + } + } + AvgPool2dFrameKernelFunctor( + scalar_t* top_data_, + const scalar_t* bottom_data_, + int64_t total_elements_, + int64_t channels_, + int64_t height_, + int64_t width_, + int pooled_height_, + int pooled_width_, + int kernel_h_, + int kernel_w_, + int stride_h_, + int stride_w_, + int pad_h_, + int pad_w_, + int divisor_override_, + bool count_include_pad_, + bool use_divisor_) + : top_data(top_data_), + bottom_data(bottom_data_), + total_elements(total_elements_), + channels(channels_), + height(height_), + width(width_), + pooled_height(pooled_height_), + pooled_width(pooled_width_), + kernel_h(kernel_h_), + kernel_w(kernel_w_), + stride_h(stride_h_), + stride_w(stride_w_), + pad_h(pad_h_), + pad_w(pad_w_), + divisor_override(divisor_override_), + count_include_pad(count_include_pad_), + use_divisor(use_divisor_) {} + + private: + scalar_t* top_data; + const scalar_t* bottom_data; + int64_t total_elements; + int64_t channels; + int64_t height; + int64_t width; + int pooled_height; + int pooled_width; + int kernel_h; + int kernel_w; + int stride_h; + int stride_w; + int pad_h; + int pad_w; + int divisor_override; + bool count_include_pad; + bool use_divisor; +}; + +template +struct AvgPool2dChannelsLastFrameKernelFunctor { + void operator()(sycl::nd_item<1> item) const { + auto index = item.get_global_linear_id(); + + if (index < total_elements) { + const int c = index % channels; + const int pw = (index / 
channels) % pooled_width; + const int ph = (index / channels / pooled_width) % pooled_height; + const int n = index / channels / pooled_width / pooled_height; + int hstart = ph * stride_h - pad_h; + int wstart = pw * stride_w - pad_w; + int hend = min(hstart + kernel_h, height + pad_h); + int wend = min(wstart + kernel_w, width + pad_w); + const int pool_size = (hend - hstart) * (wend - wstart); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + hend = min(hend, height); + wend = min(wend, width); + + if (hstart >= hend || wstart >= wend) { + top_data[index] = scalar_t(0); + return; + } + + accscalar_t aveval = accscalar_t(0); + const scalar_t* const bottom_slice = + bottom_data + n * channels * height * width + c; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + aveval += bottom_slice[(h * width + w) * channels]; + } + } + int divide_factor; + if (use_divisor) { + divide_factor = divisor_override; + } else { + if (count_include_pad) { + divide_factor = pool_size; + } else { + divide_factor = (hend - hstart) * (wend - wstart); + } + } + top_data[index] = static_cast(aveval / divide_factor); + } + } + AvgPool2dChannelsLastFrameKernelFunctor( + scalar_t* top_data_, + const scalar_t* bottom_data_, + int64_t total_elements_, + int64_t channels_, + int64_t height_, + int64_t width_, + int pooled_height_, + int pooled_width_, + int kernel_h_, + int kernel_w_, + int stride_h_, + int stride_w_, + int pad_h_, + int pad_w_, + int divisor_override_, + bool count_include_pad_, + bool use_divisor_) + : top_data(top_data_), + bottom_data(bottom_data_), + total_elements(total_elements_), + channels(channels_), + height(height_), + width(width_), + pooled_height(pooled_height_), + pooled_width(pooled_width_), + kernel_h(kernel_h_), + kernel_w(kernel_w_), + stride_h(stride_h_), + stride_w(stride_w_), + pad_h(pad_h_), + pad_w(pad_w_), + divisor_override(divisor_override_), + count_include_pad(count_include_pad_), + use_divisor(use_divisor_) {} + + private: + scalar_t* top_data; + const scalar_t* bottom_data; + int64_t total_elements; + int64_t channels; + int64_t height; + int64_t width; + int pooled_height; + int pooled_width; + int kernel_h; + int kernel_w; + int stride_h; + int stride_w; + int pad_h; + int pad_w; + int divisor_override; + bool count_include_pad; + bool use_divisor; +}; + +template +void avg_pool2d_channels_last_frame( + const int total_elements, + const Tensor& input, + const int64_t channels, + const int64_t height, + const int64_t width, + const int pooled_height, + const int pooled_width, + const int kernel_h, + const int kernel_w, + const int stride_h, + const int stride_w, + const int pad_h, + const int pad_w, + Tensor& output, + const int divisor_override, + const bool count_include_pad, + const bool use_divisor) { + scalar_t* top_data = output.data_ptr(); + const scalar_t* bottom_data = input.data_ptr(); + + auto& queue = at::xpu::getCurrentSYCLQueue(); + const uint32_t group_size = + std::min(static_cast(syclMaxWorkItemsPerEU()), 1024); + const uint32_t global_range = + ceil_div(total_elements, group_size) * group_size; + + auto caller = AvgPool2dChannelsLastFrameKernelFunctor( + top_data, + bottom_data, + total_elements, + channels, + height, + width, + pooled_height, + pooled_width, + kernel_h, + kernel_w, + stride_h, + stride_w, + pad_h, + pad_w, + divisor_override, + count_include_pad, + use_divisor); + sycl_kernel_submit(global_range, group_size, queue, caller); +} +template +void avg_pool2d_frame( + const int total_elements, + const 
Tensor& input, + const int64_t channels, + const int64_t height, + const int64_t width, + const int pooled_height, + const int pooled_width, + const int kernel_h, + const int kernel_w, + const int stride_h, + const int stride_w, + const int pad_h, + const int pad_w, + Tensor& output, + const int divisor_override, + const bool count_include_pad, + const bool use_divisor) { + scalar_t* top_data = output.data_ptr(); + const scalar_t* bottom_data = input.data_ptr(); + + auto& queue = at::xpu::getCurrentSYCLQueue(); + const uint32_t group_size = + std::min(static_cast(syclMaxWorkItemsPerEU()), 1024); + const uint32_t global_range = + ceil_div(total_elements, group_size) * group_size; + + auto caller = AvgPool2dFrameKernelFunctor( + top_data, + bottom_data, + total_elements, + channels, + height, + width, + pooled_height, + pooled_width, + kernel_h, + kernel_w, + stride_h, + stride_w, + pad_h, + pad_w, + divisor_override, + count_include_pad, + use_divisor); + sycl_kernel_submit(global_range, group_size, queue, caller); +} + +template +struct AvgPool2dChannelsLastBackwardKernelFunctor { + void operator()(sycl::nd_item<1> item) const { + index_t index = item.get_global_linear_id(); + if (index < total_elements) { + const int c = index % channels; + const int w = (index / channels) % width + pad_w; + const int h = (index / channels / width) % height + pad_h; + const int n = index / channels / width / height; + const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1; + const int phend = min(h / stride_h + 1, pooled_height); + const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1; + const int pwend = min(w / stride_w + 1, pooled_width); + accscalar_t gradient = accscalar_t(0); + const scalar_t* const top_slice = + top_data + n * channels * pooled_height * pooled_width + c; + for (int ph = phstart; ph < phend; ++ph) { + for (int pw = pwstart; pw < pwend; ++pw) { + // figure out the pooling size + int hstart = ph * stride_h - pad_h; + int wstart = pw * stride_w - pad_w; + int hend = min(hstart + kernel_h, height + pad_h); + int wend = min(wstart + kernel_w, width + pad_w); + int pool_size = (hend - hstart) * (wend - wstart); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + hend = min(hend, height); + wend = min(wend, width); + if (hstart >= hend || wstart >= wend) { + continue; + } + int divide_factor; + if (use_divisor) { + divide_factor = divisor_override; + } else { + if (count_include_pad) { + divide_factor = pool_size; + } else { + divide_factor = (hend - hstart) * (wend - wstart); + } + } + gradient += + top_slice[(ph * pooled_width + pw) * channels] / divide_factor; + } + } + bottom_data[index] = static_cast(gradient); + } + } + AvgPool2dChannelsLastBackwardKernelFunctor( + const scalar_t* top_data_, + scalar_t* bottom_data_, + int64_t total_elements_, + int64_t channels_, + int64_t height_, + int64_t width_, + int pooled_height_, + int pooled_width_, + int kernel_h_, + int kernel_w_, + int stride_h_, + int stride_w_, + int pad_h_, + int pad_w_, + int divisor_override_, + bool count_include_pad_, + bool use_divisor_) + : top_data(top_data_), + bottom_data(bottom_data_), + total_elements(total_elements_), + channels(channels_), + height(height_), + width(width_), + pooled_height(pooled_height_), + pooled_width(pooled_width_), + kernel_h(kernel_h_), + kernel_w(kernel_w_), + stride_h(stride_h_), + stride_w(stride_w_), + pad_h(pad_h_), + pad_w(pad_w_), + divisor_override(divisor_override_), + count_include_pad(count_include_pad_), + use_divisor(use_divisor_) 
{} + + private: + const scalar_t* top_data; + scalar_t* bottom_data; + int64_t total_elements; + int64_t channels; + int64_t height; + int64_t width; + int pooled_height; + int pooled_width; + int kernel_h; + int kernel_w; + int stride_h; + int stride_w; + int pad_h; + int pad_w; + int divisor_override; + bool count_include_pad; + bool use_divisor; +}; + +template +struct AvgPool2dBackwarKernelFunctor { + void operator()(sycl::nd_item<1> item) const { + index_t index = item.get_global_linear_id(); + if (index < total_elements) { + // find out the local index + // find out the local offset + const int w = index % width + pad_w; + const int h = (index / width) % height + pad_h; + const int c = (index / width / height) % channels; + const int n = index / width / height / channels; + const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1; + const int phend = min(h / stride_h + 1, pooled_height); + const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1; + const int pwend = min(w / stride_w + 1, pooled_width); + accscalar_t gradient = accscalar_t(0); + const scalar_t* const top_data_slice = + top_data + (n * channels + c) * pooled_height * pooled_width; + for (int ph = phstart; ph < phend; ++ph) { + for (int pw = pwstart; pw < pwend; ++pw) { + // figure out the pooling size + int hstart = ph * stride_h - pad_h; + int wstart = pw * stride_w - pad_w; + int hend = min(hstart + kernel_h, height + pad_h); + int wend = min(wstart + kernel_w, width + pad_w); + int pool_size = (hend - hstart) * (wend - wstart); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + hend = min(hend, height); + wend = min(wend, width); + if (hstart >= hend || wstart >= wend) { + continue; + } + int divide_factor; + if (use_divisor) { + divide_factor = divisor_override; + } else { + if (count_include_pad) { + divide_factor = pool_size; + } else { + divide_factor = (hend - hstart) * (wend - wstart); + } + } + gradient += top_data_slice[ph * pooled_width + pw] / divide_factor; + } + } + bottom_data[index] = static_cast(gradient); + } + } + AvgPool2dBackwarKernelFunctor( + const scalar_t* top_data_, + scalar_t* bottom_data_, + int64_t total_elements_, + int64_t channels_, + int64_t height_, + int64_t width_, + int pooled_height_, + int pooled_width_, + int kernel_h_, + int kernel_w_, + int stride_h_, + int stride_w_, + int pad_h_, + int pad_w_, + int divisor_override_, + bool count_include_pad_, + bool use_divisor_) + : top_data(top_data_), + bottom_data(bottom_data_), + total_elements(total_elements_), + channels(channels_), + height(height_), + width(width_), + pooled_height(pooled_height_), + pooled_width(pooled_width_), + kernel_h(kernel_h_), + kernel_w(kernel_w_), + stride_h(stride_h_), + stride_w(stride_w_), + pad_h(pad_h_), + pad_w(pad_w_), + divisor_override(divisor_override_), + count_include_pad(count_include_pad_), + use_divisor(use_divisor_) {} + + private: + const scalar_t* top_data; + scalar_t* bottom_data; + int64_t total_elements; + int64_t channels; + int64_t height; + int64_t width; + int pooled_height; + int pooled_width; + int kernel_h; + int kernel_w; + int stride_h; + int stride_w; + int pad_h; + int pad_w; + int divisor_override; + bool count_include_pad; + bool use_divisor; +}; + +template +void avg_pool2d_backward_channels_last_frame( + const index_t total_elements, + const Tensor& grad_output, + const int64_t channels, + const int64_t height, + const int64_t width, + const int pooled_height, + const int pooled_width, + const int kernel_h, + const int kernel_w, + const 
int stride_h, + const int stride_w, + const int pad_h, + const int pad_w, + Tensor& grad_input, + const int divisor_override, + bool count_include_pad, + bool use_divisor) { + const scalar_t* top_data = grad_output.data_ptr(); + scalar_t* bottom_data = grad_input.data_ptr(); + + auto& queue = at::xpu::getCurrentSYCLQueue(); + const uint32_t group_size = + std::min(static_cast(syclMaxWorkItemsPerEU()), 1024); + const uint32_t global_range = + ceil_div(total_elements, group_size) * group_size; + + auto caller = AvgPool2dChannelsLastBackwardKernelFunctor< + scalar_t, + accscalar_t, + index_t>( + top_data, + bottom_data, + total_elements, + channels, + height, + width, + pooled_height, + pooled_width, + kernel_h, + kernel_w, + stride_h, + stride_w, + pad_h, + pad_w, + divisor_override, + count_include_pad, + use_divisor); + sycl_kernel_submit(global_range, group_size, queue, caller); +} + +template +void avg_pool2d_backward_frame( + const index_t total_elements, + const Tensor& grad_output, + const int64_t channels, + const int64_t height, + const int64_t width, + const int pooled_height, + const int pooled_width, + const int kernel_h, + const int kernel_w, + const int stride_h, + const int stride_w, + const int pad_h, + const int pad_w, + Tensor& grad_input, + const int divisor_override, + bool count_include_pad, + bool use_divisor) { + const scalar_t* top_data = grad_output.data_ptr(); + scalar_t* bottom_data = grad_input.data_ptr(); + + auto& queue = at::xpu::getCurrentSYCLQueue(); + const uint32_t group_size = + std::min(static_cast(syclMaxWorkItemsPerEU()), 1024); + const uint32_t global_range = + ceil_div(total_elements, group_size) * group_size; + + auto caller = AvgPool2dBackwarKernelFunctor( + top_data, + bottom_data, + total_elements, + channels, + height, + width, + pooled_height, + pooled_width, + kernel_h, + kernel_w, + stride_h, + stride_w, + pad_h, + pad_w, + divisor_override, + count_include_pad, + use_divisor); + sycl_kernel_submit(global_range, group_size, queue, caller); +} + +void avg_pool2d_kernel( + const Tensor& input_, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + bool ceil_mode, + bool count_include_pad, + c10::optional divisor_override, + Tensor& output) { + const int kH = safe_downcast(kernel_size[0]); + const int kW = kernel_size.size() == 1 + ? kH + : safe_downcast(kernel_size[1]); + + const int dH = stride.empty() ? kH : safe_downcast(stride[0]); + const int dW = stride.empty() ? kW + : stride.size() == 1 ? dH + : safe_downcast(stride[1]); + + const int padH = safe_downcast(padding[0]); + const int padW = + padding.size() == 1 ? padH : safe_downcast(padding[1]); + + const int64_t nInputPlane = input_.size(-3); + const int64_t inputHeight = input_.size(-2); + const int64_t inputWidth = input_.size(-1); + + int64_t outputWidth = + pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode); + int64_t outputHeight = + pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode); + const auto memory_format = input_.suggest_memory_format(); + + Tensor input = input_.contiguous(memory_format); + const auto count = safe_downcast(output.numel()); + + bool use_divisor = divisor_override.has_value(); + const auto divisor_override_value = + use_divisor ? 
divisor_override.value() : 0; + if (count != 0) { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, input.scalar_type(), "avg_pool2d_out_xpu", [&] { + using accscalar_t = acc_type; + + switch (memory_format) { + case MemoryFormat::ChannelsLast: { + output.unsafeGetTensorImpl()->empty_tensor_restride( + MemoryFormat::ChannelsLast); + avg_pool2d_channels_last_frame( + count, + input, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + kH, + kW, + dH, + dW, + padH, + padW, + output, + divisor_override_value, + count_include_pad, + use_divisor); + break; + } + case MemoryFormat::Contiguous: { + avg_pool2d_frame( + count, + input, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + kH, + kW, + dH, + dW, + padH, + padW, + output, + divisor_override_value, + count_include_pad, + use_divisor); + break; + } + default: + TORCH_CHECK( + false, + "Unsupported memory format. Supports only " + "ChannelsLast, Contiguous"); + } + }); + } +} + +void avg_pool2d_backward_kernel( + const Tensor& gradOutput_, + const Tensor& input_, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + bool ceil_mode, + bool count_include_pad, + c10::optional divisor_override, + Tensor& gradInput) { + const int kH = safe_downcast(kernel_size[0]); + const int kW = kernel_size.size() == 1 + ? kH + : safe_downcast(kernel_size[1]); + + const int dH = stride.empty() ? kH : safe_downcast(stride[0]); + const int dW = stride.empty() ? kW + : stride.size() == 1 ? dH + : safe_downcast(stride[1]); + + const int padH = safe_downcast(padding[0]); + const int padW = + padding.size() == 1 ? padH : safe_downcast(padding[1]); + + const auto memory_format = input_.suggest_memory_format(); + const Tensor input = input_.contiguous(memory_format); + const Tensor gradOutput = gradOutput_.contiguous(memory_format); + + const int64_t nInputPlane = input_.size(-3); + const int64_t inputHeight = input_.size(-2); + const int64_t inputWidth = input_.size(-1); + + int64_t outputWidth = + pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode); + int64_t outputHeight = + pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode); + + const auto count = input.numel(); + if (count == 0) { + return; + } + bool use_divisor = divisor_override.has_value(); + const auto divisor_override_value = + use_divisor ? divisor_override.value() : 0; + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, input.scalar_type(), "avg_pool2d_backward_xpu", [&] { + using accscalar_t = acc_type; + + AT_DISPATCH_INDEX_TYPES( + at::native::canUse32BitIndexMath(input, INT_MAX) ? ScalarType::Int + : ScalarType::Long, + "avg_pool2d_backward_xpu_launcher", + [&] { + switch (memory_format) { + case MemoryFormat::ChannelsLast: { + gradInput.unsafeGetTensorImpl()->empty_tensor_restride( + MemoryFormat::ChannelsLast); + avg_pool2d_backward_channels_last_frame< + scalar_t, + accscalar_t, + index_t>( + count, + gradOutput, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + kH, + kW, + dH, + dW, + padH, + padW, + gradInput, + divisor_override_value, + count_include_pad, + use_divisor); + break; + } + case MemoryFormat::Contiguous: { + avg_pool2d_backward_frame( + count, + gradOutput, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + kH, + kW, + dH, + dW, + padH, + padW, + gradInput, + divisor_override_value, + count_include_pad, + use_divisor); + break; + } + default: + TORCH_CHECK( + false, + "Unsupported memory format. 
Supports only " + "ChannelsLast, Contiguous"); + } + }); + }); +} + +} // namespace xpu +} // namespace at::native diff --git a/src/aten/sycl/AveragePool2dKernels.h b/src/aten/sycl/AveragePool2dKernels.h new file mode 100644 index 000000000..b53f23593 --- /dev/null +++ b/src/aten/sycl/AveragePool2dKernels.h @@ -0,0 +1,29 @@ +#include +#include +namespace at::native { +namespace xpu { + +void avg_pool2d_kernel( + const Tensor& input_, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + bool ceil_mode, + bool count_include_pad, + c10::optional divisor_override, + Tensor& output); + +void avg_pool2d_backward_kernel( + const Tensor& gradOutput_, + const Tensor& input_, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + bool ceil_mode, + bool count_include_pad, + c10::optional divisor_override, + Tensor& gradInput); + +} // namespace xpu + +} // namespace at::native \ No newline at end of file diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml index a2227c7fa..a19db9240 100644 --- a/yaml/xpu_functions.yaml +++ b/yaml/xpu_functions.yaml @@ -334,3 +334,7 @@ supported: - sgn_ - _pin_memory - is_pinned + - avg_pool2d + - avg_pool2d.out + - avg_pool2d_backward + - avg_pool2d_backward.grad_input \ No newline at end of file From d89e1f3e2167b2d8d671adde66703e7bd2b099bb Mon Sep 17 00:00:00 2001 From: chunhuanMeng Date: Mon, 24 Jun 2024 02:54:45 +0000 Subject: [PATCH 02/16] undo for merge --- src/aten/XPUFallback.template | 2 ++ yaml/xpu_functions.yaml | 6 +----- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/aten/XPUFallback.template b/src/aten/XPUFallback.template index 94631152e..b844ad42a 100644 --- a/src/aten/XPUFallback.template +++ b/src/aten/XPUFallback.template @@ -180,6 +180,8 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) { "atan2.out", "atanh.out", "atan.out", + "avg_pool2d_backward.grad_input", + "avg_pool2d.out", "avg_pool3d_backward.grad_input", "avg_pool3d.out", "binary_cross_entropy", diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml index a19db9240..bdc7c9f7f 100644 --- a/yaml/xpu_functions.yaml +++ b/yaml/xpu_functions.yaml @@ -333,8 +333,4 @@ supported: - sgn.out - sgn_ - _pin_memory - - is_pinned - - avg_pool2d - - avg_pool2d.out - - avg_pool2d_backward - - avg_pool2d_backward.grad_input \ No newline at end of file + - is_pinned \ No newline at end of file From 5ce93eaab138e7f306e6685dcb16ab7edb6de52a Mon Sep 17 00:00:00 2001 From: chunhuanMeng Date: Mon, 24 Jun 2024 03:10:00 +0000 Subject: [PATCH 03/16] move file --- src/{aten => ATen/native/xpu}/AveragePool2d.cpp | 0 src/{aten => ATen/native/xpu}/sycl/AveragePool2dKernels.cpp | 0 src/{aten => ATen/native/xpu}/sycl/AveragePool2dKernels.h | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename src/{aten => ATen/native/xpu}/AveragePool2d.cpp (100%) rename src/{aten => ATen/native/xpu}/sycl/AveragePool2dKernels.cpp (100%) rename src/{aten => ATen/native/xpu}/sycl/AveragePool2dKernels.h (100%) diff --git a/src/aten/AveragePool2d.cpp b/src/ATen/native/xpu/AveragePool2d.cpp similarity index 100% rename from src/aten/AveragePool2d.cpp rename to src/ATen/native/xpu/AveragePool2d.cpp diff --git a/src/aten/sycl/AveragePool2dKernels.cpp b/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp similarity index 100% rename from src/aten/sycl/AveragePool2dKernels.cpp rename to src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp diff --git a/src/aten/sycl/AveragePool2dKernels.h b/src/ATen/native/xpu/sycl/AveragePool2dKernels.h similarity index 100% rename from 
src/aten/sycl/AveragePool2dKernels.h rename to src/ATen/native/xpu/sycl/AveragePool2dKernels.h From bd6bcfec8eca165b1177a6d8f7c1b542e00cf408 Mon Sep 17 00:00:00 2001 From: chunhuanMeng <105194461+chunhuanMeng@users.noreply.github.com> Date: Mon, 24 Jun 2024 11:15:46 +0800 Subject: [PATCH 04/16] Update xpu_functions.yaml --- yaml/xpu_functions.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml index 7bf1a6eee..069b92e74 100644 --- a/yaml/xpu_functions.yaml +++ b/yaml/xpu_functions.yaml @@ -380,4 +380,5 @@ supported: - acosh_ - acosh.out - addr - - addr.out \ No newline at end of file + - addr.out + From 549960325242c57bc683c48e909cc5302349036d Mon Sep 17 00:00:00 2001 From: chunhuanMeng <105194461+chunhuanMeng@users.noreply.github.com> Date: Mon, 24 Jun 2024 11:16:15 +0800 Subject: [PATCH 05/16] Update xpu_functions.yaml From 01783ef714d9424e4efc946c001236aa7fe93f72 Mon Sep 17 00:00:00 2001 From: chunhuanMeng <105194461+chunhuanMeng@users.noreply.github.com> Date: Mon, 24 Jun 2024 11:16:41 +0800 Subject: [PATCH 06/16] Update xpu_functions.yaml --- yaml/xpu_functions.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml index 069b92e74..45ad9cc55 100644 --- a/yaml/xpu_functions.yaml +++ b/yaml/xpu_functions.yaml @@ -381,4 +381,3 @@ supported: - acosh.out - addr - addr.out - From 01c28d4d562c08cc394e737090dd42de671c9a8e Mon Sep 17 00:00:00 2001 From: chunhuanMeng Date: Mon, 24 Jun 2024 03:33:39 +0000 Subject: [PATCH 07/16] enable ut for averagepool2d --- src/ATen/native/xpu/XPUFallback.template | 2 -- test/xpu/run_test_with_skip.py | 6 ++++++ test/xpu/xpu_test_utils.py | 1 + yaml/xpu_functions.yaml | 4 ++++ 4 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template index 98b408dab..f2b30b620 100644 --- a/src/ATen/native/xpu/XPUFallback.template +++ b/src/ATen/native/xpu/XPUFallback.template @@ -176,8 +176,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) { "atan2.out", "atanh.out", "atan.out", - "avg_pool2d_backward.grad_input", - "avg_pool2d.out", "avg_pool3d_backward.grad_input", "avg_pool3d.out", "binary_cross_entropy", diff --git a/test/xpu/run_test_with_skip.py b/test/xpu/run_test_with_skip.py index 53560faa4..4f4d4204a 100644 --- a/test/xpu/run_test_with_skip.py +++ b/test/xpu/run_test_with_skip.py @@ -147,6 +147,7 @@ def launch_test(test_case, skip_list=None, exe_list=None): "test_out_nn_functional_conv_transpose3d_xpu_float32", "test_out_requires_grad_error_sparse_sampled_addmm_xpu_complex64", "test_out_requires_grad_error_sparse_sampled_addmm_xpu_float32", + "test_out_nn_functional_avg_pool2d_xpu_float32", # CUDA xfail. "test_out_to_sparse_xpu_float32", "test_out_warning__native_batch_norm_legit_xpu", "test_out_warning_jiterator_2inputs_2outputs_xpu", @@ -412,6 +413,11 @@ def launch_test(test_case, skip_list=None, exe_list=None): "test_dtypes_view_as_real_xpu", # Didn't align with CUDA, The following dtypes did not work in backward but are listed by the OpInfo: {torch.bfloat16} "test_python_ref_executor__refs_pow_executor_aten_xpu_complex32", # Didn't align with CUDA, Unexpected success + "test_noncontiguous_samples_nn_functional_avg_pool2d_xpu_int64",# The implementation aligns with CUDA - "avg_pool2d_out_xpu" not implemented for 'Long'. 
+ "test_noncontiguous_samples_nn_functional_avg_pool1d_xpu_int64",# The implementation aligns with CUDA - "avg_pool2d_out_xpu" not implemented for 'Long'. + "test_noncontiguous_samples_nn_functional_local_response_norm_xpu_int64",# The implementation aligns with CUDA - "avg_pool2d_out_xpu" not implemented for 'Long'. + "test_dtypes_nn_functional_avg_pool2d_xpu",# The implementation aligns with CUDA - "avg_pool2d_out_xpu" not implemented for 'Long'. + # https://github.com/intel/torch-xpu-ops/issues/157 # Segfault: "test_dtypes_nn_functional_linear_xpu", # https://github.com/intel/torch-xpu-ops/issues/157 diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py index 707da75ad..c5a5236d7 100644 --- a/test/xpu/xpu_test_utils.py +++ b/test/xpu/xpu_test_utils.py @@ -100,6 +100,7 @@ "scatter", "gather", "max_pool2d_with_indices_backward", + "nn.functional.avg_pool2d", "nn.functional.embedding", "nn.functional.unfold", # "nn.functional.nll_loss", # Lack of XPU implementation of aten::nll_loss2d_forward. Will retrieve the case, only if the op is implemented. diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml index 45ad9cc55..8620066e2 100644 --- a/yaml/xpu_functions.yaml +++ b/yaml/xpu_functions.yaml @@ -381,3 +381,7 @@ supported: - acosh.out - addr - addr.out + - avg_pool2d + - avg_pool2d.out + - avg_pool2d_backward + - avg_pool2d_backward.grad_input From e9da9dffdd5470b8a25424e76099d33030a09dad Mon Sep 17 00:00:00 2001 From: chunhuanMeng Date: Mon, 24 Jun 2024 03:34:11 +0000 Subject: [PATCH 08/16] add skip ut --- test/xpu/fin_grain/run_fine_grain.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/xpu/fin_grain/run_fine_grain.py b/test/xpu/fin_grain/run_fine_grain.py index ddfd059f0..ec9b8f4b8 100644 --- a/test/xpu/fin_grain/run_fine_grain.py +++ b/test/xpu/fin_grain/run_fine_grain.py @@ -67,6 +67,9 @@ "test_compare_cpu_sigmoid_xpu_complex64", "test_compare_cpu_sigmoid_xpu_complex128", + # Align with CUDA dtypes - RuntimeError: "avg_pool2d_out_xpu" not implemented for 'Long' + "test_compare_cpu_nn_functional_avg_pool2d_xpu_int64", + # Special handle (different calculation order) in CPU reference impl. 
# https://github.com/pytorch/pytorch/blob/c97e3ebb96d7457075b019b94411e8c2d058e68b/aten/src/ATen/native/EmbeddingBag.cpp#L300 "test_compare_cpu_nn_functional_embedding_bag_xpu_bfloat16", From 90f63c6d8b00da08181c8d368fd37546e3f350a2 Mon Sep 17 00:00:00 2001 From: chunhuanMeng <105194461+chunhuanMeng@users.noreply.github.com> Date: Mon, 24 Jun 2024 15:09:06 +0800 Subject: [PATCH 09/16] Remove spaces --- src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp b/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp index 03a705f81..48b45ffcb 100644 --- a/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp +++ b/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include #include From 6326a7aaf2f0229ad3e2a730883086b04440b566 Mon Sep 17 00:00:00 2001 From: chunhuanMeng Date: Mon, 24 Jun 2024 07:30:35 +0000 Subject: [PATCH 10/16] correct file location --- src/ATen/native/xpu/AveragePool2d.cpp | 2 +- src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ATen/native/xpu/AveragePool2d.cpp b/src/ATen/native/xpu/AveragePool2d.cpp index 55f545b46..62faaf258 100644 --- a/src/ATen/native/xpu/AveragePool2d.cpp +++ b/src/ATen/native/xpu/AveragePool2d.cpp @@ -1,7 +1,7 @@ #include #include #include -#include +#include #include namespace at { diff --git a/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp b/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp index 48b45ffcb..8d8c57239 100644 --- a/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp +++ b/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp @@ -2,7 +2,7 @@ #include #include #include -#include +#include #include #include #include From ce6dc1e2856d3754abebf2c51131299794859e5a Mon Sep 17 00:00:00 2001 From: chunhuanMeng <105194461+chunhuanMeng@users.noreply.github.com> Date: Tue, 25 Jun 2024 09:36:30 +0800 Subject: [PATCH 11/16] header file location --- src/ATen/native/xpu/AveragePool2d.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ATen/native/xpu/AveragePool2d.cpp b/src/ATen/native/xpu/AveragePool2d.cpp index 62faaf258..189ae05e1 100644 --- a/src/ATen/native/xpu/AveragePool2d.cpp +++ b/src/ATen/native/xpu/AveragePool2d.cpp @@ -1,5 +1,5 @@ #include -#include +#include #include #include #include @@ -308,4 +308,4 @@ Tensor& XPUNativeFunctions::avg_pool2d_backward_out( return grad_input; } -} // namespace at \ No newline at end of file +} // namespace at From f61c635d4191464c875eb5b5f755cf9396af8446 Mon Sep 17 00:00:00 2001 From: chunhuanMeng <105194461+chunhuanMeng@users.noreply.github.com> Date: Wed, 26 Jun 2024 15:18:52 +0800 Subject: [PATCH 12/16] Update AveragePool2dKernels.cpp --- src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp b/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp index 8d8c57239..4e89ca77b 100644 --- a/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp +++ b/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp @@ -672,7 +672,7 @@ void avg_pool2d_kernel( use_divisor ? 
divisor_override.value() : 0; if (count != 0) { AT_DISPATCH_FLOATING_TYPES_AND2( - kHalf, kBFloat16, input.scalar_type(), "avg_pool2d_out_xpu", [&] { + kHalf, kBFloat16, input.scalar_type(), "avg_pool2d_xpu", [&] { using accscalar_t = acc_type; switch (memory_format) { From 19faadae59c0a63480efa7746abcfa9933c1ef6d Mon Sep 17 00:00:00 2001 From: chunhuanMeng Date: Fri, 28 Jun 2024 07:03:03 +0000 Subject: [PATCH 13/16] use syclMaxWorkItemsPerEU() directly --- src/ATen/native/xpu/AveragePool2d.cpp | 14 +- .../native/xpu/sycl/AveragePool2dKernels.cpp | 578 +++++++++--------- 2 files changed, 294 insertions(+), 298 deletions(-) diff --git a/src/ATen/native/xpu/AveragePool2d.cpp b/src/ATen/native/xpu/AveragePool2d.cpp index 189ae05e1..423a37955 100644 --- a/src/ATen/native/xpu/AveragePool2d.cpp +++ b/src/ATen/native/xpu/AveragePool2d.cpp @@ -1,7 +1,7 @@ #include -#include #include #include +#include #include namespace at { @@ -181,7 +181,7 @@ Tensor& avg_pool2d_backward_meta( } Tensor XPUNativeFunctions::avg_pool2d( - const at::Tensor& input, + const Tensor& input, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, @@ -212,7 +212,7 @@ Tensor XPUNativeFunctions::avg_pool2d( } Tensor& XPUNativeFunctions::avg_pool2d_out( - const at::Tensor& input, + const Tensor& input, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, @@ -243,8 +243,8 @@ Tensor& XPUNativeFunctions::avg_pool2d_out( } Tensor XPUNativeFunctions::avg_pool2d_backward( - const at::Tensor& grad_output, - const at::Tensor& input, + const Tensor& grad_output, + const Tensor& input, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, @@ -276,8 +276,8 @@ Tensor XPUNativeFunctions::avg_pool2d_backward( } Tensor& XPUNativeFunctions::avg_pool2d_backward_out( - const at::Tensor& grad_output, - const at::Tensor& input, + const Tensor& grad_output, + const Tensor& input, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, diff --git a/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp b/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp index 4e89ca77b..8f08668cc 100644 --- a/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp +++ b/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp @@ -23,103 +23,103 @@ struct AvgPool2dFrameKernelFunctor { void operator()(sycl::nd_item<1> item) const { auto index = item.get_global_linear_id(); - if (index < total_elements) { - const int pw = index % pooled_width; - const int ph = (index / pooled_width) % pooled_height; - const int c = (index / pooled_width / pooled_height) % channels; - const int n = index / pooled_width / pooled_height / channels; - - int hstart = ph * stride_h - pad_h; - int wstart = pw * stride_w - pad_w; - int hend = min(hstart + kernel_h, height + pad_h); - int wend = min(wstart + kernel_w, width + pad_w); + if (index < total_elements_) { + const int pw = index % pooled_width_; + const int ph = (index / pooled_width_) % pooled_height_; + const int c = (index / pooled_width_ / pooled_height_) % channels_; + const int n = index / pooled_width_ / pooled_height_ / channels_; + + int hstart = ph * stride_h_ - pad_h_; + int wstart = pw * stride_w_ - pad_w_; + int hend = min(hstart + kernel_h_, height_ + pad_h_); + int wend = min(wstart + kernel_w_, width_ + pad_w_); const int pool_size = (hend - hstart) * (wend - wstart); hstart = max(hstart, 0); wstart = max(wstart, 0); - hend = min(hend, height); - wend = min(wend, width); + hend = min(hend, height_); + wend = min(wend, width_); if (hstart >= 
hend || wstart >= wend) { - top_data[index] = scalar_t(0); + top_data_[index] = scalar_t(0); return; } accscalar_t aveval = accscalar_t(0); const scalar_t* const bottom_slice = - bottom_data + (n * channels + c) * height * width; + bottom_data_ + (n * channels_ + c) * height_ * width_; for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { - aveval += bottom_slice[h * width + w]; + aveval += bottom_slice[h * width_ + w]; } } int divide_factor; - if (use_divisor) { - divide_factor = divisor_override; + if (use_divisor_) { + divide_factor = divisor_override_; } else { - if (count_include_pad) { + if (count_include_pad_) { divide_factor = pool_size; } else { divide_factor = (hend - hstart) * (wend - wstart); } } - top_data[index] = static_cast(aveval / divide_factor); + top_data_[index] = static_cast(aveval / divide_factor); } } AvgPool2dFrameKernelFunctor( - scalar_t* top_data_, - const scalar_t* bottom_data_, - int64_t total_elements_, - int64_t channels_, - int64_t height_, - int64_t width_, - int pooled_height_, - int pooled_width_, - int kernel_h_, - int kernel_w_, - int stride_h_, - int stride_w_, - int pad_h_, - int pad_w_, - int divisor_override_, - bool count_include_pad_, - bool use_divisor_) - : top_data(top_data_), - bottom_data(bottom_data_), - total_elements(total_elements_), - channels(channels_), - height(height_), - width(width_), - pooled_height(pooled_height_), - pooled_width(pooled_width_), - kernel_h(kernel_h_), - kernel_w(kernel_w_), - stride_h(stride_h_), - stride_w(stride_w_), - pad_h(pad_h_), - pad_w(pad_w_), - divisor_override(divisor_override_), - count_include_pad(count_include_pad_), - use_divisor(use_divisor_) {} + scalar_t* top_data, + const scalar_t* bottom_data, + int64_t total_elements, + int64_t channels, + int64_t height, + int64_t width, + int pooled_height, + int pooled_width, + int kernel_h, + int kernel_w, + int stride_h, + int stride_w, + int pad_h, + int pad_w, + int divisor_override, + bool count_include_pad, + bool use_divisor) + : top_data_(top_data), + bottom_data_(bottom_data), + total_elements_(total_elements), + channels_(channels), + height_(height), + width_(width), + pooled_height_(pooled_height), + pooled_width_(pooled_width), + kernel_h_(kernel_h), + kernel_w_(kernel_w), + stride_h_(stride_h), + stride_w_(stride_w), + pad_h_(pad_h), + pad_w_(pad_w), + divisor_override_(divisor_override), + count_include_pad_(count_include_pad), + use_divisor_(use_divisor) {} private: - scalar_t* top_data; - const scalar_t* bottom_data; - int64_t total_elements; - int64_t channels; - int64_t height; - int64_t width; - int pooled_height; - int pooled_width; - int kernel_h; - int kernel_w; - int stride_h; - int stride_w; - int pad_h; - int pad_w; - int divisor_override; - bool count_include_pad; - bool use_divisor; + scalar_t* top_data_; + const scalar_t* bottom_data_; + int64_t total_elements_; + int64_t channels_; + int64_t height_; + int64_t width_; + int pooled_height_; + int pooled_width_; + int kernel_h_; + int kernel_w_; + int stride_h_; + int stride_w_; + int pad_h_; + int pad_w_; + int divisor_override_; + bool count_include_pad_; + bool use_divisor_; }; template @@ -127,101 +127,101 @@ struct AvgPool2dChannelsLastFrameKernelFunctor { void operator()(sycl::nd_item<1> item) const { auto index = item.get_global_linear_id(); - if (index < total_elements) { - const int c = index % channels; - const int pw = (index / channels) % pooled_width; - const int ph = (index / channels / pooled_width) % pooled_height; - const int n = index / 
channels / pooled_width / pooled_height; - int hstart = ph * stride_h - pad_h; - int wstart = pw * stride_w - pad_w; - int hend = min(hstart + kernel_h, height + pad_h); - int wend = min(wstart + kernel_w, width + pad_w); + if (index < total_elements_) { + const int c = index % channels_; + const int pw = (index / channels_) % pooled_width_; + const int ph = (index / channels_ / pooled_width_) % pooled_height_; + const int n = index / channels_ / pooled_width_ / pooled_height_; + int hstart = ph * stride_h_ - pad_h_; + int wstart = pw * stride_w_ - pad_w_; + int hend = min(hstart + kernel_h_, height_ + pad_h_); + int wend = min(wstart + kernel_w_, width_ + pad_w_); const int pool_size = (hend - hstart) * (wend - wstart); hstart = max(hstart, 0); wstart = max(wstart, 0); - hend = min(hend, height); - wend = min(wend, width); + hend = min(hend, height_); + wend = min(wend, width_); if (hstart >= hend || wstart >= wend) { - top_data[index] = scalar_t(0); + top_data_[index] = scalar_t(0); return; } accscalar_t aveval = accscalar_t(0); const scalar_t* const bottom_slice = - bottom_data + n * channels * height * width + c; + bottom_data_ + n * channels_ * height_ * width_ + c; for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { - aveval += bottom_slice[(h * width + w) * channels]; + aveval += bottom_slice[(h * width_ + w) * channels_]; } } int divide_factor; - if (use_divisor) { - divide_factor = divisor_override; + if (use_divisor_) { + divide_factor = divisor_override_; } else { - if (count_include_pad) { + if (count_include_pad_) { divide_factor = pool_size; } else { divide_factor = (hend - hstart) * (wend - wstart); } } - top_data[index] = static_cast(aveval / divide_factor); + top_data_[index] = static_cast(aveval / divide_factor); } } AvgPool2dChannelsLastFrameKernelFunctor( - scalar_t* top_data_, - const scalar_t* bottom_data_, - int64_t total_elements_, - int64_t channels_, - int64_t height_, - int64_t width_, - int pooled_height_, - int pooled_width_, - int kernel_h_, - int kernel_w_, - int stride_h_, - int stride_w_, - int pad_h_, - int pad_w_, - int divisor_override_, - bool count_include_pad_, - bool use_divisor_) - : top_data(top_data_), - bottom_data(bottom_data_), - total_elements(total_elements_), - channels(channels_), - height(height_), - width(width_), - pooled_height(pooled_height_), - pooled_width(pooled_width_), - kernel_h(kernel_h_), - kernel_w(kernel_w_), - stride_h(stride_h_), - stride_w(stride_w_), - pad_h(pad_h_), - pad_w(pad_w_), - divisor_override(divisor_override_), - count_include_pad(count_include_pad_), - use_divisor(use_divisor_) {} + scalar_t* top_data, + const scalar_t* bottom_data, + int64_t total_elements, + int64_t channels, + int64_t height, + int64_t width, + int pooled_height, + int pooled_width, + int kernel_h, + int kernel_w, + int stride_h, + int stride_w, + int pad_h, + int pad_w, + int divisor_override, + bool count_include_pad, + bool use_divisor) + : top_data_(top_data), + bottom_data_(bottom_data), + total_elements_(total_elements), + channels_(channels), + height_(height), + width_(width), + pooled_height_(pooled_height), + pooled_width_(pooled_width), + kernel_h_(kernel_h), + kernel_w_(kernel_w), + stride_h_(stride_h), + stride_w_(stride_w), + pad_h_(pad_h), + pad_w_(pad_w), + divisor_override_(divisor_override), + count_include_pad_(count_include_pad), + use_divisor_(use_divisor) {} private: - scalar_t* top_data; - const scalar_t* bottom_data; - int64_t total_elements; - int64_t channels; - int64_t height; - int64_t 
width; - int pooled_height; - int pooled_width; - int kernel_h; - int kernel_w; - int stride_h; - int stride_w; - int pad_h; - int pad_w; - int divisor_override; - bool count_include_pad; - bool use_divisor; + scalar_t* top_data_; + const scalar_t* bottom_data_; + int64_t total_elements_; + int64_t channels_; + int64_t height_; + int64_t width_; + int pooled_height_; + int pooled_width_; + int kernel_h_; + int kernel_w_; + int stride_h_; + int stride_w_; + int pad_h_; + int pad_w_; + int divisor_override_; + bool count_include_pad_; + bool use_divisor_; }; template @@ -247,8 +247,7 @@ void avg_pool2d_channels_last_frame( const scalar_t* bottom_data = input.data_ptr(); auto& queue = at::xpu::getCurrentSYCLQueue(); - const uint32_t group_size = - std::min(static_cast(syclMaxWorkItemsPerEU()), 1024); + const uint32_t group_size = static_cast(syclMaxWorkItemsPerEU()); const uint32_t global_range = ceil_div(total_elements, group_size) * group_size; @@ -295,8 +294,7 @@ void avg_pool2d_frame( const scalar_t* bottom_data = input.data_ptr(); auto& queue = at::xpu::getCurrentSYCLQueue(); - const uint32_t group_size = - std::min(static_cast(syclMaxWorkItemsPerEU()), 1024); + const uint32_t group_size = static_cast(syclMaxWorkItemsPerEU()); const uint32_t global_range = ceil_div(total_elements, group_size) * group_size; @@ -325,209 +323,209 @@ template struct AvgPool2dChannelsLastBackwardKernelFunctor { void operator()(sycl::nd_item<1> item) const { index_t index = item.get_global_linear_id(); - if (index < total_elements) { - const int c = index % channels; - const int w = (index / channels) % width + pad_w; - const int h = (index / channels / width) % height + pad_h; - const int n = index / channels / width / height; - const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1; - const int phend = min(h / stride_h + 1, pooled_height); - const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1; - const int pwend = min(w / stride_w + 1, pooled_width); + if (index < total_elements_) { + const int c = index % channels_; + const int w = (index / channels_) % width_ + pad_w_; + const int h = (index / channels_ / width_) % height_ + pad_h_; + const int n = index / channels_ / width_ / height_; + const int phstart = (h < kernel_h_) ? 0 : (h - kernel_h_) / stride_h_ + 1; + const int phend = min(h / stride_h_ + 1, pooled_height_); + const int pwstart = (w < kernel_w_) ? 
0 : (w - kernel_w_) / stride_w_ + 1; + const int pwend = min(w / stride_w_ + 1, pooled_width_); accscalar_t gradient = accscalar_t(0); const scalar_t* const top_slice = - top_data + n * channels * pooled_height * pooled_width + c; + top_data_ + n * channels_ * pooled_height_ * pooled_width_ + c; for (int ph = phstart; ph < phend; ++ph) { for (int pw = pwstart; pw < pwend; ++pw) { // figure out the pooling size - int hstart = ph * stride_h - pad_h; - int wstart = pw * stride_w - pad_w; - int hend = min(hstart + kernel_h, height + pad_h); - int wend = min(wstart + kernel_w, width + pad_w); + int hstart = ph * stride_h_ - pad_h_; + int wstart = pw * stride_w_ - pad_w_; + int hend = min(hstart + kernel_h_, height_ + pad_h_); + int wend = min(wstart + kernel_w_, width_ + pad_w_); int pool_size = (hend - hstart) * (wend - wstart); hstart = max(hstart, 0); wstart = max(wstart, 0); - hend = min(hend, height); - wend = min(wend, width); + hend = min(hend, height_); + wend = min(wend, width_); if (hstart >= hend || wstart >= wend) { continue; } int divide_factor; - if (use_divisor) { - divide_factor = divisor_override; + if (use_divisor_) { + divide_factor = divisor_override_; } else { - if (count_include_pad) { + if (count_include_pad_) { divide_factor = pool_size; } else { divide_factor = (hend - hstart) * (wend - wstart); } } gradient += - top_slice[(ph * pooled_width + pw) * channels] / divide_factor; + top_slice[(ph * pooled_width_ + pw) * channels_] / divide_factor; } } - bottom_data[index] = static_cast(gradient); + bottom_data_[index] = static_cast(gradient); } } AvgPool2dChannelsLastBackwardKernelFunctor( - const scalar_t* top_data_, - scalar_t* bottom_data_, - int64_t total_elements_, - int64_t channels_, - int64_t height_, - int64_t width_, - int pooled_height_, - int pooled_width_, - int kernel_h_, - int kernel_w_, - int stride_h_, - int stride_w_, - int pad_h_, - int pad_w_, - int divisor_override_, - bool count_include_pad_, - bool use_divisor_) - : top_data(top_data_), - bottom_data(bottom_data_), - total_elements(total_elements_), - channels(channels_), - height(height_), - width(width_), - pooled_height(pooled_height_), - pooled_width(pooled_width_), - kernel_h(kernel_h_), - kernel_w(kernel_w_), - stride_h(stride_h_), - stride_w(stride_w_), - pad_h(pad_h_), - pad_w(pad_w_), - divisor_override(divisor_override_), - count_include_pad(count_include_pad_), - use_divisor(use_divisor_) {} + const scalar_t* top_data, + scalar_t* bottom_data, + int64_t total_elements, + int64_t channels, + int64_t height, + int64_t width, + int pooled_height, + int pooled_width, + int kernel_h, + int kernel_w, + int stride_h, + int stride_w, + int pad_h, + int pad_w, + int divisor_override, + bool count_include_pad, + bool use_divisor) + : top_data_(top_data), + bottom_data_(bottom_data), + total_elements_(total_elements), + channels_(channels), + height_(height), + width_(width), + pooled_height_(pooled_height), + pooled_width_(pooled_width), + kernel_h_(kernel_h), + kernel_w_(kernel_w), + stride_h_(stride_h), + stride_w_(stride_w), + pad_h_(pad_h), + pad_w_(pad_w), + divisor_override_(divisor_override), + count_include_pad_(count_include_pad), + use_divisor_(use_divisor) {} private: - const scalar_t* top_data; - scalar_t* bottom_data; - int64_t total_elements; - int64_t channels; - int64_t height; - int64_t width; - int pooled_height; - int pooled_width; - int kernel_h; - int kernel_w; - int stride_h; - int stride_w; - int pad_h; - int pad_w; - int divisor_override; - bool count_include_pad; - bool 
 
 template <typename scalar_t, typename accscalar_t, typename index_t>
 struct AvgPool2dBackwarKernelFunctor {
   void operator()(sycl::nd_item<1> item) const {
     index_t index = item.get_global_linear_id();
-    if (index < total_elements) {
+    if (index < total_elements_) {
       // find out the local index
       // find out the local offset
-      const int w = index % width + pad_w;
-      const int h = (index / width) % height + pad_h;
-      const int c = (index / width / height) % channels;
-      const int n = index / width / height / channels;
-      const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
-      const int phend = min(h / stride_h + 1, pooled_height);
-      const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
-      const int pwend = min(w / stride_w + 1, pooled_width);
+      const int w = index % width_ + pad_w_;
+      const int h = (index / width_) % height_ + pad_h_;
+      const int c = (index / width_ / height_) % channels_;
+      const int n = index / width_ / height_ / channels_;
+      const int phstart = (h < kernel_h_) ? 0 : (h - kernel_h_) / stride_h_ + 1;
+      const int phend = min(h / stride_h_ + 1, pooled_height_);
+      const int pwstart = (w < kernel_w_) ? 0 : (w - kernel_w_) / stride_w_ + 1;
+      const int pwend = min(w / stride_w_ + 1, pooled_width_);
       accscalar_t gradient = accscalar_t(0);
       const scalar_t* const top_data_slice =
-          top_data + (n * channels + c) * pooled_height * pooled_width;
+          top_data_ + (n * channels_ + c) * pooled_height_ * pooled_width_;
       for (int ph = phstart; ph < phend; ++ph) {
         for (int pw = pwstart; pw < pwend; ++pw) {
           // figure out the pooling size
-          int hstart = ph * stride_h - pad_h;
-          int wstart = pw * stride_w - pad_w;
-          int hend = min(hstart + kernel_h, height + pad_h);
-          int wend = min(wstart + kernel_w, width + pad_w);
+          int hstart = ph * stride_h_ - pad_h_;
+          int wstart = pw * stride_w_ - pad_w_;
+          int hend = min(hstart + kernel_h_, height_ + pad_h_);
+          int wend = min(wstart + kernel_w_, width_ + pad_w_);
           int pool_size = (hend - hstart) * (wend - wstart);
           hstart = max(hstart, 0);
           wstart = max(wstart, 0);
-          hend = min(hend, height);
-          wend = min(wend, width);
+          hend = min(hend, height_);
+          wend = min(wend, width_);
           if (hstart >= hend || wstart >= wend) {
             continue;
           }
           int divide_factor;
-          if (use_divisor) {
-            divide_factor = divisor_override;
+          if (use_divisor_) {
+            divide_factor = divisor_override_;
           } else {
-            if (count_include_pad) {
+            if (count_include_pad_) {
               divide_factor = pool_size;
             } else {
               divide_factor = (hend - hstart) * (wend - wstart);
             }
           }
-          gradient += top_data_slice[ph * pooled_width + pw] / divide_factor;
+          gradient += top_data_slice[ph * pooled_width_ + pw] / divide_factor;
         }
       }
-      bottom_data[index] = static_cast<scalar_t>(gradient);
+      bottom_data_[index] = static_cast<scalar_t>(gradient);
     }
   }
   AvgPool2dBackwarKernelFunctor(
-      const scalar_t* top_data_,
-      scalar_t* bottom_data_,
-      int64_t total_elements_,
-      int64_t channels_,
-      int64_t height_,
-      int64_t width_,
-      int pooled_height_,
-      int pooled_width_,
-      int kernel_h_,
-      int kernel_w_,
-      int stride_h_,
-      int stride_w_,
-      int pad_h_,
-      int pad_w_,
-      int divisor_override_,
-      bool count_include_pad_,
-      bool use_divisor_)
-      : top_data(top_data_),
-        bottom_data(bottom_data_),
-        total_elements(total_elements_),
-        channels(channels_),
-        height(height_),
-        width(width_),
-        pooled_height(pooled_height_),
-        pooled_width(pooled_width_),
-        kernel_h(kernel_h_),
-        kernel_w(kernel_w_),
-        stride_h(stride_h_),
-        stride_w(stride_w_),
-        pad_h(pad_h_),
-        pad_w(pad_w_),
-        divisor_override(divisor_override_),
-        count_include_pad(count_include_pad_),
-        use_divisor(use_divisor_) {}
+      const scalar_t* top_data,
+      scalar_t* bottom_data,
+      int64_t total_elements,
+      int64_t channels,
+      int64_t height,
+      int64_t width,
+      int pooled_height,
+      int pooled_width,
+      int kernel_h,
+      int kernel_w,
+      int stride_h,
+      int stride_w,
+      int pad_h,
+      int pad_w,
+      int divisor_override,
+      bool count_include_pad,
+      bool use_divisor)
+      : top_data_(top_data),
+        bottom_data_(bottom_data),
+        total_elements_(total_elements),
+        channels_(channels),
+        height_(height),
+        width_(width),
+        pooled_height_(pooled_height),
+        pooled_width_(pooled_width),
+        kernel_h_(kernel_h),
+        kernel_w_(kernel_w),
+        stride_h_(stride_h),
+        stride_w_(stride_w),
+        pad_h_(pad_h),
+        pad_w_(pad_w),
+        divisor_override_(divisor_override),
+        count_include_pad_(count_include_pad),
+        use_divisor_(use_divisor) {}
 
  private:
-  const scalar_t* top_data;
-  scalar_t* bottom_data;
-  int64_t total_elements;
-  int64_t channels;
-  int64_t height;
-  int64_t width;
-  int pooled_height;
-  int pooled_width;
-  int kernel_h;
-  int kernel_w;
-  int stride_h;
-  int stride_w;
-  int pad_h;
-  int pad_w;
-  int divisor_override;
-  bool count_include_pad;
-  bool use_divisor;
+  const scalar_t* top_data_;
+  scalar_t* bottom_data_;
+  int64_t total_elements_;
+  int64_t channels_;
+  int64_t height_;
+  int64_t width_;
+  int pooled_height_;
+  int pooled_width_;
+  int kernel_h_;
+  int kernel_w_;
+  int stride_h_;
+  int stride_w_;
+  int pad_h_;
+  int pad_w_;
+  int divisor_override_;
+  bool count_include_pad_;
+  bool use_divisor_;
 };
 
 template <typename scalar_t, typename accscalar_t, typename index_t>
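
The divide_factor branch in both backward functors is where count_include_pad
and divisor_override diverge from plain window averaging: pool_size is taken
before clipping to the valid region, while the count_include_pad=false path
divides by the clipped extent. The user-visible effect is easiest to see at a
padded corner (PyTorch sketch, runs on CPU):

    import torch
    import torch.nn.functional as F

    x = torch.ones(1, 1, 4, 4)
    a = F.avg_pool2d(x, 3, stride=1, padding=1, count_include_pad=True)
    b = F.avg_pool2d(x, 3, stride=1, padding=1, count_include_pad=False)
    print(a[0, 0, 0, 0].item())  # 4/9 ~= 0.444: padded zeros stay in the divisor
    print(b[0, 0, 0, 0].item())  # 1.0: divisor shrinks to the 4 valid elements
    c = F.avg_pool2d(x, 3, stride=1, padding=1, divisor_override=2)
    print(c[0, 0, 0, 0].item())  # 2.0: the override wins over both modes
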
@@ -553,8 +551,7 @@ void avg_pool2d_backward_channels_last_frame(
   scalar_t* bottom_data = grad_input.data_ptr<scalar_t>();
 
   auto& queue = at::xpu::getCurrentSYCLQueue();
-  const uint32_t group_size =
-      std::min(static_cast<uint32_t>(syclMaxWorkItemsPerEU()), 1024);
+  const uint32_t group_size = static_cast<uint32_t>(syclMaxWorkItemsPerEU());
   const uint32_t global_range =
       ceil_div(total_elements, group_size) * group_size;
 
@@ -605,8 +602,7 @@ void avg_pool2d_backward_frame(
   scalar_t* bottom_data = grad_input.data_ptr<scalar_t>();
 
   auto& queue = at::xpu::getCurrentSYCLQueue();
-  const uint32_t group_size =
-      std::min(static_cast<uint32_t>(syclMaxWorkItemsPerEU()), 1024);
+  const uint32_t group_size = static_cast<uint32_t>(syclMaxWorkItemsPerEU());
   const uint32_t global_range =
       ceil_div(total_elements, group_size) * group_size;
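
Both backward functors gather, into each input element, the contribution of
every output window that covers it: grad_in[i] is the sum over covering
windows w of grad_out[w] / divide_factor(w). That identity is easy to
cross-check against autograd (illustrative Python, runs on CPU):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 1, 4, 4, requires_grad=True)
    F.avg_pool2d(x, kernel_size=2, stride=2).sum().backward()
    # Non-overlapping 2x2 windows: each input sits in exactly one window,
    # so every gradient entry is 1 / (kH * kW) = 0.25.
    assert torch.allclose(x.grad, torch.full_like(x.grad, 0.25))
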
From cb5b094a05d400d43396ed2ee51f4fea67d06db0 Mon Sep 17 00:00:00 2001
From: Feng Yuan
Date: Sat, 29 Jun 2024 22:11:22 +0800
Subject: [PATCH 14/16] Coding style

---
 src/ATen/native/xpu/AveragePool2d.cpp         |  1 +
 .../native/xpu/sycl/AveragePool2dKernels.cpp  | 44 ++++++++++---------
 .../native/xpu/sycl/AveragePool2dKernels.h    |  9 ++--
 3 files changed, 27 insertions(+), 27 deletions(-)

diff --git a/src/ATen/native/xpu/AveragePool2d.cpp b/src/ATen/native/xpu/AveragePool2d.cpp
index 423a37955..f54d7a2fc 100644
--- a/src/ATen/native/xpu/AveragePool2d.cpp
+++ b/src/ATen/native/xpu/AveragePool2d.cpp
@@ -69,6 +69,7 @@ Tensor& avg_pool2d_meta(
       outputHeight,
       outputWidth,
       memory_format);
+
   /* resize output */
   if (input.ndimension() == 3) {
     if (output.defined()) {
diff --git a/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp b/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp
index 8f08668cc..c82b6777d 100644
--- a/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp
+++ b/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp
@@ -2,6 +2,7 @@
 #include
 #include
 #include
+
 #include
 #include
 #include
@@ -19,7 +20,7 @@ inline int max(int a, int b) {
 }
 
 template <typename scalar_t, typename accscalar_t>
-struct AvgPool2dFrameKernelFunctor {
+struct AvgPool2dKernelFunctor {
   void operator()(sycl::nd_item<1> item) const {
     auto index = item.get_global_linear_id();
@@ -66,7 +67,7 @@ struct AvgPool2dFrameKernelFunctor {
       top_data_[index] = static_cast<scalar_t>(aveval / divide_factor);
     }
   }
-  AvgPool2dFrameKernelFunctor(
+  AvgPool2dKernelFunctor(
       scalar_t* top_data,
       const scalar_t* bottom_data,
       int64_t total_elements,
@@ -123,7 +124,7 @@ struct AvgPool2dFrameKernelFunctor {
 };
 
 template <typename scalar_t, typename accscalar_t>
-struct AvgPool2dChannelsLastFrameKernelFunctor {
+struct AvgPool2dChannelsLastKernelFunctor {
   void operator()(sycl::nd_item<1> item) const {
     auto index = item.get_global_linear_id();
@@ -168,7 +169,7 @@ struct AvgPool2dChannelsLastFrameKernelFunctor {
       top_data_[index] = static_cast<scalar_t>(aveval / divide_factor);
     }
   }
-  AvgPool2dChannelsLastFrameKernelFunctor(
+  AvgPool2dChannelsLastKernelFunctor(
       scalar_t* top_data,
       const scalar_t* bottom_data,
       int64_t total_elements,
@@ -225,7 +226,7 @@ struct AvgPool2dChannelsLastFrameKernelFunctor {
 };
 
 template <typename scalar_t, typename accscalar_t>
-void avg_pool2d_channels_last_frame(
+void launch_avg_pool2d_channels_last_kernel(
     const int total_elements,
    const Tensor& input,
    const int64_t channels,
@@ -251,7 +252,7 @@ void avg_pool2d_channels_last_frame(
   const uint32_t global_range =
       ceil_div(total_elements, group_size) * group_size;
 
-  auto caller = AvgPool2dChannelsLastFrameKernelFunctor<scalar_t, accscalar_t>(
+  auto kfn = AvgPool2dChannelsLastKernelFunctor<scalar_t, accscalar_t>(
       top_data,
       bottom_data,
       total_elements,
@@ -269,10 +270,11 @@ void avg_pool2d_channels_last_frame(
       divisor_override,
       count_include_pad,
       use_divisor);
-  sycl_kernel_submit(global_range, group_size, queue, caller);
+  sycl_kernel_submit(global_range, group_size, queue, kfn);
 }
+
 template <typename scalar_t, typename accscalar_t>
-void avg_pool2d_frame(
+void launch_avg_pool2d_kernel(
     const int total_elements,
     const Tensor& input,
     const int64_t channels,
@@ -298,7 +300,7 @@ void avg_pool2d_frame(
   const uint32_t global_range =
       ceil_div(total_elements, group_size) * group_size;
 
-  auto caller = AvgPool2dFrameKernelFunctor<scalar_t, accscalar_t>(
+  auto kfn = AvgPool2dKernelFunctor<scalar_t, accscalar_t>(
       top_data,
       bottom_data,
       total_elements,
@@ -316,7 +318,7 @@ void avg_pool2d_frame(
       divisor_override,
       count_include_pad,
       use_divisor);
-  sycl_kernel_submit(global_range, group_size, queue, caller);
+  sycl_kernel_submit(global_range, group_size, queue, kfn);
 }
 
 template <typename scalar_t, typename accscalar_t, typename index_t>
@@ -529,7 +531,7 @@ struct AvgPool2dBackwarKernelFunctor {
 };
 
 template <typename scalar_t, typename accscalar_t, typename index_t>
-void avg_pool2d_backward_channels_last_frame(
+void launch_avg_pool2d_backward_channels_last_kernel(
     const index_t total_elements,
     const Tensor& grad_output,
     const int64_t channels,
@@ -555,7 +557,7 @@ void avg_pool2d_backward_channels_last_frame(
   const uint32_t global_range =
       ceil_div(total_elements, group_size) * group_size;
 
-  auto caller = AvgPool2dChannelsLastBackwardKernelFunctor<
+  auto kfn = AvgPool2dChannelsLastBackwardKernelFunctor<
       scalar_t,
       accscalar_t,
       index_t>(
       top_data,
       bottom_data,
       total_elements,
@@ -576,11 +578,11 @@ void avg_pool2d_backward_channels_last_frame(
       divisor_override,
       count_include_pad,
       use_divisor);
-  sycl_kernel_submit(global_range, group_size, queue, caller);
+  sycl_kernel_submit(global_range, group_size, queue, kfn);
 }
 
 template <typename scalar_t, typename accscalar_t, typename index_t>
-void avg_pool2d_backward_frame(
+void launch_avg_pool2d_backward_kernel(
     const index_t total_elements,
     const Tensor& grad_output,
     const int64_t channels,
@@ -606,7 +608,7 @@ void avg_pool2d_backward_frame(
   const uint32_t global_range =
       ceil_div(total_elements, group_size) * group_size;
 
-  auto caller = AvgPool2dBackwarKernelFunctor<scalar_t, accscalar_t, index_t>(
+  auto kfn = AvgPool2dBackwarKernelFunctor<scalar_t, accscalar_t, index_t>(
       top_data,
       bottom_data,
       total_elements,
@@ -624,7 +626,7 @@ void avg_pool2d_backward_frame(
       divisor_override,
       count_include_pad,
       use_divisor);
-  sycl_kernel_submit(global_range, group_size, queue, caller);
+  sycl_kernel_submit(global_range, group_size, queue, kfn);
 }
 
 void avg_pool2d_kernel(
@@ -675,7 +677,7 @@ void avg_pool2d_kernel(
       case MemoryFormat::ChannelsLast: {
         output.unsafeGetTensorImpl()->empty_tensor_restride(
             MemoryFormat::ChannelsLast);
-        avg_pool2d_channels_last_frame<scalar_t, accscalar_t>(
+        launch_avg_pool2d_channels_last_kernel<scalar_t, accscalar_t>(
            count,
            input,
            nInputPlane,
@@ -696,7 +698,7 @@ void avg_pool2d_kernel(
         break;
       }
       case MemoryFormat::Contiguous: {
-        avg_pool2d_frame<scalar_t, accscalar_t>(
+        launch_avg_pool2d_kernel<scalar_t, accscalar_t>(
            count,
            input,
            nInputPlane,
@@ -777,13 +779,13 @@ void avg_pool2d_backward_kernel(
   AT_DISPATCH_INDEX_TYPES(
       at::native::canUse32BitIndexMath(input, INT_MAX) ? ScalarType::Int
                                                        : ScalarType::Long,
-      "avg_pool2d_backward_xpu_launcher",
+      "avg_pool2d_backward_xpu",
       [&] {
        switch (memory_format) {
          case MemoryFormat::ChannelsLast: {
            gradInput.unsafeGetTensorImpl()->empty_tensor_restride(
                MemoryFormat::ChannelsLast);
-            avg_pool2d_backward_channels_last_frame<
+            launch_avg_pool2d_backward_channels_last_kernel<
                scalar_t,
                accscalar_t,
                index_t>(
@@ -807,7 +809,7 @@ void avg_pool2d_backward_kernel(
            break;
          }
          case MemoryFormat::Contiguous: {
-            avg_pool2d_backward_frame<scalar_t, accscalar_t, index_t>(
+            launch_avg_pool2d_backward_kernel<scalar_t, accscalar_t, index_t>(
              count,
              gradOutput,
              nInputPlane,
diff --git a/src/ATen/native/xpu/sycl/AveragePool2dKernels.h b/src/ATen/native/xpu/sycl/AveragePool2dKernels.h
index b53f23593..079235857 100644
--- a/src/ATen/native/xpu/sycl/AveragePool2dKernels.h
+++ b/src/ATen/native/xpu/sycl/AveragePool2dKernels.h
@@ -1,7 +1,6 @@
 #include
-#include
-namespace at::native {
-namespace xpu {
+
+namespace at::native::xpu {
 
 void avg_pool2d_kernel(
     const Tensor& input_,
@@ -24,6 +23,4 @@ void avg_pool2d_backward_kernel(
     c10::optional<int64_t> divisor_override,
     Tensor& gradInput);
 
-} // namespace xpu
-
-} // namespace at::native
\ No newline at end of file
+} // namespace at::native::xpu
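
The renamed launcher pairs exist because the flat-index decomposition differs
per memory format: the contiguous kernels peel off w first, the channels-last
kernels peel off c first. A sketch of the two decompositions as the functors
perform them (illustrative Python, not part of the patch):

    # Contiguous (NCHW): fastest-moving dimension is W.
    def decompose_nchw(index, C, H, W):
        w = index % W
        h = (index // W) % H
        c = (index // (W * H)) % C
        n = index // (W * H * C)
        return n, c, h, w

    # Channels-last (NHWC): fastest-moving dimension is C.
    def decompose_nhwc(index, C, H, W):
        c = index % C
        w = (index // C) % W
        h = (index // (C * W)) % H
        n = index // (C * W * H)
        return n, c, h, w

Keeping one work item per logical element in both layouts lets neighboring
work items touch adjacent memory, which is the point of dispatching on the
suggested memory format rather than forcing a contiguous copy.
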
From 00831810f6ce09b22973d520a5f9f26e8cf121b4 Mon Sep 17 00:00:00 2001
From: Feng Yuan
Date: Sat, 29 Jun 2024 22:48:54 +0800
Subject: [PATCH 15/16] Fixing compilation issues

---
 src/ATen/native/xpu/AveragePool2d.cpp                     | 4 +++-
 src/ATen/native/xpu/sycl/ActivationHardsigmoidKernels.cpp | 3 +--
 src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp         | 6 +++++-
 src/ATen/native/xpu/sycl/Indexing.cpp                     | 2 +-
 4 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/src/ATen/native/xpu/AveragePool2d.cpp b/src/ATen/native/xpu/AveragePool2d.cpp
index f54d7a2fc..4d3cc1c0e 100644
--- a/src/ATen/native/xpu/AveragePool2d.cpp
+++ b/src/ATen/native/xpu/AveragePool2d.cpp
@@ -1,7 +1,9 @@
 #include
 #include
-#include
+#include
 #include
+
+#include
 #include
 
 namespace at {
diff --git a/src/ATen/native/xpu/sycl/ActivationHardsigmoidKernels.cpp b/src/ATen/native/xpu/sycl/ActivationHardsigmoidKernels.cpp
index 62853c9fe..d2c53e2b0 100644
--- a/src/ATen/native/xpu/sycl/ActivationHardsigmoidKernels.cpp
+++ b/src/ATen/native/xpu/sycl/ActivationHardsigmoidKernels.cpp
@@ -39,8 +39,7 @@ void hardsigmoid_kernel(TensorIteratorBase& iter) {
         const opmath_t one_sixth(1.0f / 6.0f);
         const opmath_t three(3.0f);
         const opmath_t six(6.0f);
-        HardsigmoidFunctor f(
-            zero, one_sixth, three, six);
+        HardsigmoidFunctor f(zero, one_sixth, three, six);
         gpu_kernel(iter, f);
       });
 }
diff --git a/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp b/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp
index c82b6777d..f6c3abec8 100644
--- a/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp
+++ b/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp
@@ -2,6 +2,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
@@ -809,7 +810,10 @@ void avg_pool2d_backward_kernel(
             break;
           }
           case MemoryFormat::Contiguous: {
-            launch_avg_pool2d_backward_kernel<scalar_t, accscalar_t, index_t>(
+            launch_avg_pool2d_backward_kernel<
+                scalar_t,
+                accscalar_t,
+                index_t>(
                 count,
                 gradOutput,
                 nInputPlane,
diff --git a/src/ATen/native/xpu/sycl/Indexing.cpp b/src/ATen/native/xpu/sycl/Indexing.cpp
index 725a04281..f7fa13146 100644
--- a/src/ATen/native/xpu/sycl/Indexing.cpp
+++ b/src/ATen/native/xpu/sycl/Indexing.cpp
@@ -378,7 +378,7 @@ void index_add_kernel(
       source_.scalar_type(),
       "index_add_xpu",
       [&] {
-        AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_xpu", [&] () {
+        AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_xpu", [&]() {
           TensorInfo index_info = getTensorInfo(index);
           index_info.collapseDims();
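
On the dispatch side, AT_DISPATCH_INDEX_TYPES instantiates the backward
launchers for int or int64_t depending on canUse32BitIndexMath, keeping
index arithmetic cheap whenever the tensor's addressable extent fits in a
32-bit integer. Roughly (simplified illustrative sketch; the real check also
accounts for strides and storage offset):

    INT_MAX = 2**31 - 1

    def prefers_32bit_indexing(numel):
        # Simplified stand-in for at::native::canUse32BitIndexMath.
        return numel <= INT_MAX

    print(prefers_32bit_indexing(10**9))   # True  -> index_t = int
    print(prefers_32bit_indexing(10**10))  # False -> index_t = int64_t
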
From cc5122fca0d64fe05e3bd184a1ce269b0025c9c8 Mon Sep 17 00:00:00 2001
From: Feng Yuan
Date: Sun, 30 Jun 2024 17:13:56 +0800
Subject: [PATCH 16/16] Remove cases

---
 test/xpu/run_test_with_skip.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/test/xpu/run_test_with_skip.py b/test/xpu/run_test_with_skip.py
index 28fd4e2b2..525b90ba8 100644
--- a/test/xpu/run_test_with_skip.py
+++ b/test/xpu/run_test_with_skip.py
@@ -800,6 +800,13 @@ def launch_test(test_case, skip_list=None, exe_list=None):
     # https://github.com/intel/torch-xpu-ops/issues/468
     "test_dtypes_nn_functional_interpolate_bilinear_xpu",
     "test_dtypes_nn_functional_interpolate_bicubic_xpu",
+
+    # Op impl aligns with CUDA on the supported dtypes.
+    # RuntimeError: "avg_pool2d_xpu" not implemented for 'Long'.
+    # Re-enable these cases once avg_pool1d is supported: the test infra will then
+    # pick up the claimed dtypes from the XPU supported-operators list and they will pass.
+    "test_noncontiguous_samples_nn_functional_avg_pool1d_xpu_int64",
+    "test_noncontiguous_samples_nn_functional_local_response_norm_xpu_int64"
 )
 
 res += launch_test("test_ops_xpu.py", skip_list)
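
The two removed cases feed int64 inputs to ops that route through
avg_pool2d, and, as the comment notes, the XPU implementation follows
CUDA's dtype coverage, so such inputs raise rather than compute. A hedged
illustration of the failure mode (assumes a torch build with an "xpu"
device; exact error text may vary):

    import torch
    import torch.nn.functional as F

    x = torch.arange(16, dtype=torch.int64).reshape(1, 1, 16).to("xpu")
    try:
        F.avg_pool1d(x, kernel_size=2)
    except RuntimeError as err:
        # Expected along the lines of:
        # "avg_pool2d_xpu" not implemented for 'Long'
        print(err)
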