Skip to content

Commit

Permalink
Fix TensorEquals for different contiguous tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
mrkn committed Jul 1, 2019
1 parent 91b4cbc commit e1b9b9e
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 1 deletion.
8 changes: 7 additions & 1 deletion cpp/src/arrow/compare.cc
Original file line number Diff line number Diff line change
Expand Up @@ -970,7 +970,13 @@ bool TensorEquals(const Tensor& left, const Tensor& right) {
} else if (left.size() == 0) {
are_equal = true;
} else {
if (!left.is_contiguous() || !right.is_contiguous()) {
const bool left_row_major_p = left.is_row_major();
const bool left_column_major_p = left.is_column_major();
const bool right_row_major_p = right.is_row_major();
const bool right_column_major_p = right.is_column_major();

if (!(left_row_major_p && right_row_major_p) &&
!(left_column_major_p && right_column_major_p)) {
const auto& shape = left.shape();
if (shape != right.shape()) {
are_equal = false;
Expand Down
51 changes: 51 additions & 0 deletions cpp/src/arrow/tensor-test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,57 @@ TEST(TestTensor, CountNonZeroForNonContiguousTensor) {
AssertCountNonZero(t, 8);
}

TEST(TestTensor, Equals) {
std::vector<int64_t> shape = {4, 4};

std::vector<int64_t> c_values = { 1, 2, 3, 4,
5, 6, 7, 8,
9, 10, 11, 12,
13, 14, 15, 16};
std::vector<int64_t> c_strides = {32, 8};
Tensor tc1(int64(), Buffer::Wrap(c_values), shape, c_strides);
Tensor tc2(int64(), Buffer::Wrap(c_values), shape, c_strides);

std::vector<int64_t> f_values = { 1, 5, 9, 13,
2, 6, 10, 14,
3, 7, 11, 15,
4, 8, 12, 16};
std::vector<int64_t> f_strides = {8, 32};
Tensor tf1(int64(), Buffer::Wrap(f_values), shape, f_strides);

std::vector<int64_t> nc_values = { 1, 0, 5, 0, 9, 0, 13, 0,
2, 0, 6, 0, 10, 0, 14, 0,
3, 0, 7, 0, 11, 0, 15, 0,
4, 0, 8, 0, 12, 0, 16, 0};
std::vector<int64_t> nc_strides = {16, 64};
Tensor tnc(int64(), Buffer::Wrap(nc_values), shape, nc_strides);

ASSERT_TRUE(tc1.is_contiguous());
ASSERT_TRUE(tc1.is_row_major());

ASSERT_TRUE(tf1.is_contiguous());
ASSERT_TRUE(tf1.is_column_major());

ASSERT_FALSE(tnc.is_contiguous());

// same object
EXPECT_TRUE(tc1.Equals(tc1));
EXPECT_TRUE(tf1.Equals(tf1));
EXPECT_TRUE(tnc.Equals(tnc));

// different objects
EXPECT_TRUE(tc1.Equals(tc2));

// row-major and column-major
EXPECT_TRUE(tc1.Equals(tf1));

// row-major and non-contiguous
EXPECT_TRUE(tc1.Equals(tnc));

// column-major and non-contiguous
EXPECT_TRUE(tf1.Equals(tnc));
}

TEST(TestNumericTensor, ElementAccessWithRowMajorStrides) {
std::vector<int64_t> shape = {3, 4};

Expand Down

0 comments on commit e1b9b9e

Please sign in to comment.