Skip to content

Commit

Permalink
fix sparse csr (#42271)
Browse files Browse the repository at this point in the history
  • Loading branch information
tiancaishaonvjituizi authored Apr 27, 2022
1 parent d1e0123 commit b9bfcf1
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
2 changes: 1 addition & 1 deletion paddle/phi/api/lib/api_gen_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ phi::TensorBase* SetSparseKernelOutput(Tensor* out, TensorType type) {
std::make_shared<phi::SparseCsrTensor>(phi::DenseTensor(),
phi::DenseTensor(),
phi::DenseTensor(),
phi::DDim{-1});
phi::DDim{-1, -1});
out->set_impl(sparse_tensor);
return sparse_tensor.get();
} else {
Expand Down
8 changes: 5 additions & 3 deletions paddle/phi/core/sparse_csr_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@ SparseCsrTensor::SparseCsrTensor() {
inline void check_shape(const DDim& dims) {
bool valid = dims.size() == 2 || dims.size() == 3;

PADDLE_ENFORCE(valid,
phi::errors::InvalidArgument(
"the SparseCsrTensor only support 2-D Tensor."));
PADDLE_ENFORCE(
valid,
phi::errors::InvalidArgument("the SparseCsrTensor only support 2-D or "
"3-D Tensor, but get %d-D Tensor",
dims.size()));
}
#define Check(non_zero_crows, non_zero_cols, non_zero_elements, dims) \
{ \
Expand Down

0 comments on commit b9bfcf1

Please sign in to comment.