diff --git a/cpp/src/arrow/compare.cc b/cpp/src/arrow/compare.cc index 4ae5d89791727..e1525a4f4d6cb 100644 --- a/cpp/src/arrow/compare.cc +++ b/cpp/src/arrow/compare.cc @@ -970,7 +970,13 @@ bool TensorEquals(const Tensor& left, const Tensor& right) { } else if (left.size() == 0) { are_equal = true; } else { - if (!left.is_contiguous() || !right.is_contiguous()) { + const bool left_row_major_p = left.is_row_major(); + const bool left_column_major_p = left.is_column_major(); + const bool right_row_major_p = right.is_row_major(); + const bool right_column_major_p = right.is_column_major(); + + if (!(left_row_major_p && right_row_major_p) && + !(left_column_major_p && right_column_major_p)) { const auto& shape = left.shape(); if (shape != right.shape()) { are_equal = false; diff --git a/cpp/src/arrow/tensor-test.cc b/cpp/src/arrow/tensor-test.cc index 36e97434d28e5..4638cd7739bed 100644 --- a/cpp/src/arrow/tensor-test.cc +++ b/cpp/src/arrow/tensor-test.cc @@ -155,6 +155,56 @@ TEST(TestTensor, CountNonZeroForNonContiguousTensor) { AssertCountNonZero(t, 8); } +TEST(TestTensor, Equals) { + std::vector shape = {4, 4}; + + std::vector c_values = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + std::vector c_strides = {32, 8}; + Tensor tc1(int64(), Buffer::Wrap(c_values), shape, c_strides); + Tensor tc2(int64(), Buffer::Wrap(c_values), shape, c_strides); + + std::vector f_values = {1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16}; + Tensor tc3(int64(), Buffer::Wrap(f_values), shape, c_strides); + + std::vector f_strides = {8, 32}; + Tensor tf1(int64(), Buffer::Wrap(f_values), shape, f_strides); + Tensor tf2(int64(), Buffer::Wrap(c_values), shape, f_strides); + + std::vector nc_values = {1, 0, 5, 0, 9, 0, 13, 0, 2, 0, 6, 0, 10, 0, 14, 0, + 3, 0, 7, 0, 11, 0, 15, 0, 4, 0, 8, 0, 12, 0, 16, 0}; + std::vector nc_strides = {16, 64}; + Tensor tnc(int64(), Buffer::Wrap(nc_values), shape, nc_strides); + + ASSERT_TRUE(tc1.is_contiguous()); + ASSERT_TRUE(tc1.is_row_major()); + + ASSERT_TRUE(tf1.is_contiguous()); + ASSERT_TRUE(tf1.is_column_major()); + + ASSERT_FALSE(tnc.is_contiguous()); + + // same object + EXPECT_TRUE(tc1.Equals(tc1)); + EXPECT_TRUE(tf1.Equals(tf1)); + EXPECT_TRUE(tnc.Equals(tnc)); + + // different objects + EXPECT_TRUE(tc1.Equals(tc2)); + EXPECT_FALSE(tc1.Equals(tc3)); + + // row-major and column-major + EXPECT_TRUE(tc1.Equals(tf1)); + EXPECT_FALSE(tc3.Equals(tf1)); + + // row-major and non-contiguous + EXPECT_TRUE(tc1.Equals(tnc)); + EXPECT_FALSE(tc3.Equals(tnc)); + + // column-major and non-contiguous + EXPECT_TRUE(tf1.Equals(tnc)); + EXPECT_FALSE(tf2.Equals(tnc)); +} + TEST(TestNumericTensor, ElementAccessWithRowMajorStrides) { std::vector shape = {3, 4};