Skip to content

Commit

Permalink
Add sparse conv3d kernel (#39879)
Browse files Browse the repository at this point in the history
* fix incorrect dims settings

* sparse conv3d

* fix out dims

* test performance

* test large shape success

* opt scatter, double performance

* test float16

* remove profiling code

* remove pten

* opt code lines

* correct boundary judgment

* only cpu

* test ci

* test ci

* remove the including paddle/fluid header; extract the conmmon function

* opt code lines

* use DenseTensor::data() instead of mutable_data

* return rulebook for backward

* specify layout

* rename:conv -> sparse_conv3d
  • Loading branch information
zhangkaihuo authored Feb 28, 2022
1 parent 61443a0 commit bc99a76
Show file tree
Hide file tree
Showing 9 changed files with 967 additions and 15 deletions.
2 changes: 1 addition & 1 deletion paddle/phi/core/sparse_coo_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ void SparseCooTensor::SetMember(const DenseTensor& non_zero_indices,
const bool coalesced) {
this->non_zero_indices_ = non_zero_indices;
this->non_zero_elements_ = non_zero_elements;
this->dims_ = dims_;
this->dims_ = dims;
this->coalesced_ = coalesced;
}

Expand Down
148 changes: 148 additions & 0 deletions paddle/phi/kernels/sparse/convolution_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/kernels/empty_kernel.h"

namespace phi {
namespace sparse {

struct Dims4D {
int dims[4];
Dims4D(const int batch, const int x, const int y, const int z) {
dims[0] = batch;
dims[1] = z;
dims[2] = y;
dims[3] = x;
}
HOSTDEVICE const int& operator[](int i) const { return dims[i]; }
};

// Judge whether the current position x is in (lower, upper)
inline HOSTDEVICE bool Check(const int& x,
const int& kx,
const int& pad,
const int& stride,
const int dilation,
const int kdim,
const int xdim) {
const int lower = x - dilation * kx + pad;
const int uper = x + (kdim - kx - 1) * dilation - pad;
return (lower >= 0 && lower % stride == 0 && uper < xdim);
}

// Check whether the current position(x, y, z) is legal:
// Judge the minimum and maximum values at each latitude
inline HOSTDEVICE bool Check(const Dims4D& dims,
const Dims4D& kernel_dims,
const Dims4D& paddings,
const Dims4D& dilations,
const Dims4D& strides,
const int x,
const int y,
const int z,
const int kx,
const int ky,
const int kz) {
bool x_valid = Check(
x, kx, paddings[3], strides[3], dilations[3], kernel_dims[3], dims[3]);
bool y_valid = Check(
y, ky, paddings[2], strides[2], dilations[2], kernel_dims[2], dims[2]);
bool z_valid = Check(
z, kz, paddings[1], strides[1], dilations[1], kernel_dims[1], dims[1]);
return (x_valid && y_valid && z_valid);
}

template <typename Dim>
inline HOSTDEVICE int PointToIndex(const int& batch,
const int& x,
const int& y,
const int& z,
const Dim& dims) {
return batch * dims[1] * dims[2] * dims[3] + z * dims[2] * dims[3] +
y * dims[3] + x;
}

template <typename Dim>
inline HOSTDEVICE void IndexToPoint(
const int index, const Dim& dims, int* batch, int* x, int* y, int* z) {
int n = index;
*x = n % dims[3];
n /= dims[3];
*y = n % dims[2];
n /= dims[2];
*z = n % dims[1];
n /= dims[1];
*batch = n;
}

inline void GetOutShape(const DDim& x_dims,
const DDim& kernel_dims,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
DDim* out_dims) {
PADDLE_ENFORCE_EQ(
x_dims.size(),
5,
phi::errors::InvalidArgument("the shape of x should be (N, D, H, W, C)"));
PADDLE_ENFORCE_EQ(kernel_dims.size(),
5,
phi::errors::InvalidArgument(
"the shape of kernel should be (D, H, W, C, OC)"));

// infer out shape
(*out_dims)[0] = x_dims[0];
(*out_dims)[4] = kernel_dims[4];
for (int i = 1; i < 4; i++) {
(*out_dims)[i] = (x_dims[i] + 2 * paddings[i - 1] -
dilations[i - 1] * (kernel_dims[i - 1] - 1) - 1) /
strides[i - 1] +
1;
}
}

template <typename T, typename Context>
void Conv3dKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& kernel,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const int groups,
SparseCooTensor* out,
DenseTensor* rulebook);

template <typename T, typename Context>
SparseCooTensor Conv3d(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor kernel,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const int groups,
DenseTensor* rulebook) {
DenseTensor indices = phi::Empty<T, Context>(dev_ctx);
DenseTensor values = phi::Empty<T, Context>(dev_ctx);
SparseCooTensor coo(indices, values, x.dims());
Conv3dKernel<T, Context>(
dev_ctx, x, kernel, paddings, dilations, strides, groups, &coo, rulebook);
return coo;
}

} // namespace sparse
} // namespace phi
181 changes: 181 additions & 0 deletions paddle/phi/kernels/sparse/cpu/convolution.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <set>

#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"

namespace phi {
namespace sparse {

// such as: kernel(3, 3, 3), kernel_size = 27
// counter_per_weight: (kernel_size)
// TODO(zhangkaihuo): optimize performance with multithreading
template <typename T, typename Context>
void ProductRuleBook(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& kernel,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const DDim& out_dims,
DenseTensor* rulebook,
DenseTensor* counter_per_kernel) {
const auto& kernel_dims = kernel.dims();
const int64_t non_zero_num = x.nnz();
const auto& non_zero_indices = x.non_zero_indices();
const int* indices_ptr = non_zero_indices.data<int>();
dev_ctx.Alloc(counter_per_kernel,
counter_per_kernel->dtype(),
sizeof(int) * counter_per_kernel->numel());
int* counter_ptr = counter_per_kernel->data<int>();
int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2];
memset(counter_ptr, 0, kernel_size * sizeof(int));

int rulebook_len = 0;
// calc the rulebook_len
const auto& x_dims = x.dims();
const Dims4D c_x_dims(x_dims[0], x_dims[3], x_dims[2], x_dims[1]);
const Dims4D c_kernel_dims(1, kernel_dims[2], kernel_dims[1], kernel_dims[0]);
const Dims4D c_out_dims(out_dims[0], out_dims[3], out_dims[2], out_dims[1]);
const Dims4D c_paddings(1, paddings[2], paddings[1], paddings[0]);
const Dims4D c_strides(1, strides[2], strides[1], strides[0]);
const Dims4D c_dilations(1, dilations[2], dilations[1], dilations[0]);

auto f_calc_rulebook = [&](int* rulebook_ptr) {
int kernel_index = 0, rulebook_index = 0;
for (int kz = 0; kz < kernel_dims[0]; kz++) {
for (int ky = 0; ky < kernel_dims[1]; ky++) {
for (int kx = 0; kx < kernel_dims[2]; kx++) {
for (int64_t i = 0; i < non_zero_num; i++) {
int batch = indices_ptr[i];
int in_z = indices_ptr[i + non_zero_num];
int in_y = indices_ptr[i + 2 * non_zero_num];
int in_x = indices_ptr[i + 3 * non_zero_num];
int out_z = (in_z + paddings[0] - kz * dilations[0]) / strides[0];
int out_y = (in_y + paddings[1] - ky * dilations[1]) / strides[1];
int out_x = (in_x + paddings[2] - kx * dilations[2]) / strides[2];
if (Check(c_x_dims,
c_kernel_dims,
c_paddings,
c_dilations,
c_strides,
in_x,
in_y,
in_z,
kx,
ky,
kz)) {
if (rulebook_ptr == nullptr) {
counter_ptr[kernel_index] += 1;
++rulebook_len;
} else {
rulebook_ptr[rulebook_index] = kernel_index;
rulebook_ptr[rulebook_index + rulebook_len] = i; // in_i
rulebook_ptr[rulebook_index + rulebook_len * 2] =
PointToIndex<DDim>(
batch, out_x, out_y, out_z, out_dims); // out_index
++rulebook_index;
}
}
}
++kernel_index;
}
}
}
};

f_calc_rulebook(nullptr);
// alloc the rulebook
rulebook->ResizeAndAllocate({3, rulebook_len});
dev_ctx.Alloc(rulebook, rulebook->dtype(), rulebook->numel() * sizeof(int));
int* rulebook_ptr = rulebook->data<int>();
f_calc_rulebook(rulebook_ptr);
}

template <typename T, typename Context>
void UpdateRulebookAndOutIndex(const Context& dev_ctx,
const SparseCooTensor& x,
const int kernel_size,
const int out_channels,
const DDim& out_dims,
DenseTensor* rulebook,
SparseCooTensor* out) {
std::set<int> out_indexs;
int n = rulebook->dims()[1];
int* rulebook_ptr = rulebook->data<int>();
for (int i = 0; i < n; i++) {
out_indexs.insert(rulebook_ptr[i + n * 2]);
}

int out_non_zero_num = out_indexs.size();
const int64_t sparse_dim = 4;
DenseTensorMeta indices_meta(
DataType::INT32, {sparse_dim, out_non_zero_num}, DataLayout::NCHW);
DenseTensorMeta values_meta(
x.dtype(), {out_non_zero_num, out_channels}, x.layout());
phi::DenseTensor out_indices = phi::Empty(dev_ctx, std::move(indices_meta));
phi::DenseTensor out_values = phi::Empty(dev_ctx, std::move(values_meta));
dev_ctx.Alloc(
&out_indices, out_indices.dtype(), out_indices.numel() * sizeof(int));
int* out_indices_ptr = out_indices.data<int>();
int i = 0;
for (auto it = out_indexs.begin(); it != out_indexs.end(); it++, i++) {
const int index = *it;
int batch, x, y, z;
IndexToPoint<DDim>(index, out_dims, &batch, &x, &y, &z);
out_indices_ptr[i] = batch;
out_indices_ptr[i + out_non_zero_num] = z;
out_indices_ptr[i + out_non_zero_num * 2] = y;
out_indices_ptr[i + out_non_zero_num * 3] = x;
}
for (i = 0; i < n; i++) {
int out_index = rulebook_ptr[i + n * 2];
rulebook_ptr[i + n * 2] =
std::distance(out_indexs.begin(), out_indexs.find(out_index));
}

out->SetMember(out_indices, out_values, out_dims, true);
}

template <typename T>
void Gather(
const T* x, const int* indexs, const int n, const int channels, T* out) {
for (int i = 0; i < n; i++) {
int real_i = indexs[i];
memcpy(out + i * channels, x + real_i * channels, channels * sizeof(T));
}
}

template <typename T>
void Scatter(
const T* x, const int* indexs, const int n, const int channels, T* out) {
for (int i = 0; i < n; i++) {
int real_i = indexs[i];
for (int j = 0; j < channels; j++) {
out[real_i * channels + j] += x[i * channels + j];
}
}
}

} // namespace sparse
} // namespace phi
Loading

0 comments on commit bc99a76

Please sign in to comment.