diff --git a/paddle/phi/kernels/impl/deformable_conv_kernel_impl.h b/paddle/phi/kernels/impl/deformable_conv_kernel_impl.h index f864c2e5f0ed0..119c7ad52202b 100644 --- a/paddle/phi/kernels/impl/deformable_conv_kernel_impl.h +++ b/paddle/phi/kernels/impl/deformable_conv_kernel_impl.h @@ -19,6 +19,7 @@ #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/deformable_conv_functor.h" +#include "paddle/phi/kernels/transpose_kernel.h" #include "paddle/utils/optional.h" namespace phi { @@ -38,6 +39,11 @@ void DeformableConvKernel(const Context& dev_ctx, DenseTensor* out) { const int batch_size = static_cast(x.dims()[0]); + int temp_step = std::min(64, batch_size); + if (batch_size % temp_step == 0) { + im2col_step = temp_step; + } + std::vector filter_shape_vec(phi::vectorize(filter.dims())); std::vector output_shape_vec(phi::vectorize(out->dims())); @@ -101,8 +107,11 @@ void DeformableConvKernel(const Context& dev_ctx, dilations, deformable_groups, col_buffer_ptr); - DenseTensor output_3d = output_4d.Slice(i, i + 1).Resize( - phi::slice_ddim(output_4d.dims(), 1, output_4d.dims().size())); + DenseTensor output_3d = output_4d.Slice(i, i + 1).Resize(phi::slice_ddim( + output_4d.dims(), + 1, + output_4d.dims().size())); // group * C/group * (im2step * H * W) + // get the product of pixel and weight for (int g = 0; g < groups; ++g) { DenseTensor weight_3d_slice = weight_3d.Slice(g, g + 1).Resize( @@ -110,8 +119,11 @@ void DeformableConvKernel(const Context& dev_ctx, DenseTensor col_buffer_3d_slice = col_buffer_3d.Slice(g, g + 1).Resize(phi::slice_ddim( col_buffer_3d.dims(), 1, col_buffer_3d.dims().size())); - DenseTensor output_3d_slice = output_3d.Slice(g, g + 1).Resize( - phi::slice_ddim(output_3d.dims(), 1, output_3d.dims().size())); + DenseTensor output_3d_slice = + output_3d.Slice(g, g + 1).Resize(phi::slice_ddim( + output_3d.dims(), + 1, + output_3d.dims().size())); // C * ((im2col_step)*H*W)) blas.MatMul(weight_3d_slice, false, col_buffer_3d_slice, @@ -121,7 +133,29 @@ void DeformableConvKernel(const Context& dev_ctx, T(0.0)); } } - out->ShareDataWith(output_buffer).Resize(phi::make_ddim(output_shape_vec)); + + // swap axis to get the right result when im2col_step is greater than 1 + if (im2col_step > 1) { + std::vector axis(4); + axis[0] = 0; + axis[1] = 2; + axis[2] = 1; + axis[3] = 3; + + DenseTensor real_output_buffer = phi::Transpose( + dev_ctx, + output_4d.Resize( + phi::make_ddim({batch_size / im2col_step, + output_shape_vec[1], + im2col_step, + output_shape_vec[2] * output_shape_vec[3]})), + axis); + + out->ShareDataWith(real_output_buffer) + .Resize(phi::make_ddim(output_shape_vec)); + } else { + out->ShareDataWith(output_buffer).Resize(phi::make_ddim(output_shape_vec)); + } } } // namespace phi