Skip to content

Commit

Permalink
change axes_arr, starts_arr and ends_arr to axes, starts and ends
Browse files Browse the repository at this point in the history
  • Loading branch information
ScottWong98 committed Jun 2, 2023
1 parent 00e2eca commit ea48959
Show file tree
Hide file tree
Showing 6 changed files with 240 additions and 143 deletions.
87 changes: 56 additions & 31 deletions paddle/phi/kernels/sparse/cpu/slice_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,14 @@ namespace phi {
namespace sparse {

template <typename T, typename Context>
void SliceCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& out_grad,
const phi::IntArray& axes_arr,
const phi::IntArray& starts_arr,
const phi::IntArray& ends_arr,
SparseCooTensor* x_grad) {
const phi::DDim& x_dims = x.dims();

std::vector<int64_t> axes = axes_arr.GetData();
std::vector<int64_t> starts = starts_arr.GetData();
std::vector<int64_t> ends = ends_arr.GetData();

// Step1: update starts and ends
funcs::CheckAndUpdateSparseSliceAttrs<int64_t>(x_dims, &axes, &starts, &ends);

// Step2: set x_grad
void SliceCooGradCompute(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& out_grad,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
SparseCooTensor* x_grad) {
// set x_grad
const int64_t out_grad_nnz = out_grad.nnz();
auto sparse_dim = static_cast<int64_t>(out_grad.sparse_dim());
DenseTensor dx_indices =
Expand Down Expand Up @@ -69,6 +60,27 @@ void SliceCooGradKernel(const Context& dev_ctx,
x_grad->SetMember(dx_indices, dx_values, x.dims(), x.coalesced());
}

template <typename T, typename Context>
void SliceCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& out_grad,
const phi::IntArray& axes,
const phi::IntArray& starts,
const phi::IntArray& ends,
SparseCooTensor* x_grad) {
const phi::DDim& x_dims = x.dims();
std::vector<int64_t> axes_vec = axes.GetData();
std::vector<int64_t> starts_vec = starts.GetData();
std::vector<int64_t> ends_vec = ends.GetData();

// update starts and ends
funcs::CheckAndUpdateSparseSliceAttrs<int64_t>(
x_dims, &axes_vec, &starts_vec, &ends_vec);

SliceCooGradCompute<T, Context>(
dev_ctx, x, out_grad, axes_vec, starts_vec, ends_vec, x_grad);
}

template <typename T>
void GetCsrInputGradCrows(const int64_t* out_grad_crows_data,
const int64_t out_grad_n_rows,
Expand Down Expand Up @@ -184,22 +196,15 @@ void SliceCsrGrad3D(const Context& dev_ctx,
}

template <typename T, typename Context>
void SliceCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& out_grad,
const phi::IntArray& axes_arr,
const phi::IntArray& starts_arr,
const phi::IntArray& ends_arr,
SparseCsrTensor* x_grad) {
void SliceCsrGradCompute(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& out_grad,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
SparseCsrTensor* x_grad) {
const phi::DDim& x_dims = x.dims();

std::vector<int64_t> axes = axes_arr.GetData();
std::vector<int64_t> starts = starts_arr.GetData();
std::vector<int64_t> ends = ends_arr.GetData();

// Update starts and ends
funcs::CheckAndUpdateSparseSliceAttrs<int64_t>(x_dims, &axes, &starts, &ends);

// Construct new axes, starts, and ends
std::vector<int64_t> new_axes(3), new_starts(3), new_ends(3);
funcs::ConstructNewSliceAttrs(
Expand All @@ -221,6 +226,26 @@ void SliceCsrGradKernel(const Context& dev_ctx,
}
}

template <typename T, typename Context>
void SliceCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& out_grad,
const phi::IntArray& axes,
const phi::IntArray& starts,
const phi::IntArray& ends,
SparseCsrTensor* x_grad) {
const phi::DDim& x_dims = x.dims();
std::vector<int64_t> axes_vec = axes.GetData();
std::vector<int64_t> starts_vec = starts.GetData();
std::vector<int64_t> ends_vec = ends.GetData();

// Update starts and ends
funcs::CheckAndUpdateSparseSliceAttrs<int64_t>(
x_dims, &axes_vec, &starts_vec, &ends_vec);

SliceCsrGradCompute<T, Context>(
dev_ctx, x, out_grad, axes_vec, starts_vec, ends_vec, x_grad);
}
} // namespace sparse
} // namespace phi

Expand Down
86 changes: 54 additions & 32 deletions paddle/phi/kernels/sparse/cpu/slice_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,26 +24,19 @@ namespace phi {
namespace sparse {

template <typename T, typename Context>
void SliceCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const phi::IntArray& axes_arr,
const phi::IntArray& starts_arr,
const phi::IntArray& ends_arr,
SparseCooTensor* out) {
void SliceCooCompute(const Context& dev_ctx,
const SparseCooTensor& x,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
SparseCooTensor* out) {
const phi::DDim& x_dims = x.dims();

std::vector<int64_t> axes = axes_arr.GetData();
std::vector<int64_t> starts = starts_arr.GetData();
std::vector<int64_t> ends = ends_arr.GetData();

// Step1: Check and update attr
funcs::CheckAndUpdateSparseSliceAttrs<int64_t>(x_dims, &axes, &starts, &ends);

// Step2: Infer output dims
// Step1: Infer output dims
auto out_dims = funcs::GetSliceDims<int64_t>(
x_dims, axes, starts, ends, nullptr, nullptr);

// Step3: Get out_nnz (the number of non-zero elements in output)
// Step2: Get out_nnz (the number of non-zero elements in output)
const int64_t x_nnz = x.nnz();
int64_t out_nnz = 0;
const auto* x_indices_data = x.indices().data<int64_t>();
Expand All @@ -60,7 +53,7 @@ void SliceCooKernel(const Context& dev_ctx,
out_nnz++;
}

// Step4: Get the values and indices of output
// Step3: Get the values and indices of output
auto sparse_dim = static_cast<int64_t>(x.sparse_dim());
DenseTensor out_indices =
phi::Empty<int64_t, Context>(dev_ctx, {sparse_dim, out_nnz});
Expand Down Expand Up @@ -95,6 +88,25 @@ void SliceCooKernel(const Context& dev_ctx,
out->SetMember(out_indices, out_values, out_dims, x.coalesced());
}

template <typename T, typename Context>
void SliceCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const phi::IntArray& axes,
const phi::IntArray& starts,
const phi::IntArray& ends,
SparseCooTensor* out) {
const phi::DDim& x_dims = x.dims();
std::vector<int64_t> axes_vec = axes.GetData();
std::vector<int64_t> starts_vec = starts.GetData();
std::vector<int64_t> ends_vec = ends.GetData();

// Check and update attr
funcs::CheckAndUpdateSparseSliceAttrs<int64_t>(
x_dims, &axes_vec, &starts_vec, &ends_vec);

SliceCooCompute<T, Context>(dev_ctx, x, axes_vec, starts_vec, ends_vec, out);
}

int64_t GetCsrNonZeroNumber(const SparseCsrTensor& x,
const int64_t x_crows_start,
const int64_t x_crows_end,
Expand Down Expand Up @@ -242,31 +254,24 @@ void SliceCsrTensor3D(const Context& dev_ctx,
}

template <typename T, typename Context>
void SliceCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const phi::IntArray& axes_arr,
const phi::IntArray& starts_arr,
const phi::IntArray& ends_arr,
SparseCsrTensor* out) {
void SliceCsrCompute(const Context& dev_ctx,
const SparseCsrTensor& x,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
SparseCsrTensor* out) {
const phi::DDim& x_dims = x.dims();

std::vector<int64_t> axes = axes_arr.GetData();
std::vector<int64_t> starts = starts_arr.GetData();
std::vector<int64_t> ends = ends_arr.GetData();

// Step1: Check and update attr
funcs::CheckAndUpdateSparseSliceAttrs<int64_t>(x_dims, &axes, &starts, &ends);

// Step2: Infer output dims
// Step1: Infer output dims
auto out_dims = funcs::GetSliceDims<int64_t>(
x_dims, axes, starts, ends, nullptr, nullptr);

// Step3: Construct new axes, starts and ends.
// Step2: Construct new axes, starts and ends.
std::vector<int64_t> new_axes(3), new_starts(3), new_ends(3);
funcs::ConstructNewSliceAttrs(
x_dims, axes, starts, ends, &new_axes, &new_starts, &new_ends);

// Setp4: Slice csr tensor according to its dimension
// Setp3: Slice csr tensor according to its dimension
if (x_dims.size() == 2) {
SliceCsrTensor2D<T, Context>(
dev_ctx, x, new_axes, new_starts, new_ends, out_dims, out);
Expand All @@ -281,6 +286,23 @@ void SliceCsrKernel(const Context& dev_ctx,
}
}

template <typename T, typename Context>
void SliceCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const phi::IntArray& axes,
const phi::IntArray& starts,
const phi::IntArray& ends,
SparseCsrTensor* out) {
const phi::DDim& x_dims = x.dims();
std::vector<int64_t> axes_vec = axes.GetData();
std::vector<int64_t> starts_vec = starts.GetData();
std::vector<int64_t> ends_vec = ends.GetData();

// Check and update attr
funcs::CheckAndUpdateSparseSliceAttrs<int64_t>(
x_dims, &axes_vec, &starts_vec, &ends_vec);
SliceCsrCompute<T, Context>(dev_ctx, x, axes_vec, starts_vec, ends_vec, out);
}
} // namespace sparse
} // namespace phi

Expand Down
83 changes: 55 additions & 28 deletions paddle/phi/kernels/sparse/gpu/slice_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -51,22 +51,15 @@ __global__ void GetCooInputGradCudaKernel(const int64_t* out_grad_indices_data,
}
}
template <typename T, typename Context>
void SliceCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& out_grad,
const phi::IntArray& axes_arr,
const phi::IntArray& starts_arr,
const phi::IntArray& ends_arr,
SparseCooTensor* x_grad) {
void SliceCooGradCompute(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& out_grad,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
SparseCooTensor* x_grad) {
const phi::DDim& x_dims = x.dims();

std::vector<int64_t> axes = axes_arr.GetData();
std::vector<int64_t> starts = starts_arr.GetData();
std::vector<int64_t> ends = ends_arr.GetData();

// Step1: Check and update sparse slice attrs
funcs::CheckAndUpdateSparseSliceAttrs<int64_t>(x_dims, &axes, &starts, &ends);

// copy axes to device
auto d_axes_tensor = memory_utils::Alloc(
dev_ctx.GetPlace(),
Expand Down Expand Up @@ -123,6 +116,26 @@ void SliceCooGradKernel(const Context& dev_ctx,
dx_values_data);
}

template <typename T, typename Context>
void SliceCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& out_grad,
const phi::IntArray& axes,
const phi::IntArray& starts,
const phi::IntArray& ends,
SparseCooTensor* x_grad) {
const phi::DDim& x_dims = x.dims();
std::vector<int64_t> axes_vec = axes.GetData();
std::vector<int64_t> starts_vec = starts.GetData();
std::vector<int64_t> ends_vec = ends.GetData();
// Check and update sparse slice attrs
funcs::CheckAndUpdateSparseSliceAttrs<int64_t>(
x_dims, &axes_vec, &starts_vec, &ends_vec);

SliceCooGradCompute<T, Context>(
dev_ctx, x, out_grad, axes_vec, starts_vec, ends_vec, x_grad);
}

template <typename T>
__global__ void GetCsrInputColsValuesCudaKernel(
const int64_t* out_grad_cols_data,
Expand Down Expand Up @@ -283,22 +296,15 @@ void SliceCsrGrad3D(const Context& dev_ctx,
}

template <typename T, typename Context>
void SliceCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& out_grad,
const phi::IntArray& axes_arr,
const phi::IntArray& starts_arr,
const phi::IntArray& ends_arr,
SparseCsrTensor* x_grad) {
void SliceCsrGradCompute(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& out_grad,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
SparseCsrTensor* x_grad) {
const phi::DDim& x_dims = x.dims();

std::vector<int64_t> axes = axes_arr.GetData();
std::vector<int64_t> starts = starts_arr.GetData();
std::vector<int64_t> ends = ends_arr.GetData();

// update starts and ends
funcs::CheckAndUpdateSparseSliceAttrs<int64_t>(x_dims, &axes, &starts, &ends);

// construct new axes, starts, and ends
std::vector<int64_t> new_axes(3), new_starts(3), new_ends(3);
funcs::ConstructNewSliceAttrs(
Expand All @@ -319,6 +325,27 @@ void SliceCsrGradKernel(const Context& dev_ctx,
x_dims.size());
}
}

template <typename T, typename Context>
void SliceCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& out_grad,
const phi::IntArray& axes,
const phi::IntArray& starts,
const phi::IntArray& ends,
SparseCsrTensor* x_grad) {
const phi::DDim& x_dims = x.dims();
std::vector<int64_t> axes_vec = axes.GetData();
std::vector<int64_t> starts_vec = starts.GetData();
std::vector<int64_t> ends_vec = ends.GetData();
// update starts and ends
funcs::CheckAndUpdateSparseSliceAttrs<int64_t>(
x_dims, &axes_vec, &starts_vec, &ends_vec);

SliceCsrGradCompute<T, Context>(
dev_ctx, x, out_grad, axes_vec, starts_vec, ends_vec, x_grad);
}

} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(slice_coo_grad,
Expand Down
Loading

0 comments on commit ea48959

Please sign in to comment.