From b7d5a5d97b2057e141668ed45fc805b32df14640 Mon Sep 17 00:00:00 2001 From: Sasha Krassovsky Date: Tue, 2 Nov 2021 13:58:03 -0700 Subject: [PATCH] Treat null comparison result as false --- cpp/src/arrow/compute/exec/hash_join.cc | 16 +++++-- .../arrow/compute/exec/hash_join_node_test.cc | 47 +++++++++++-------- 2 files changed, 39 insertions(+), 24 deletions(-) diff --git a/cpp/src/arrow/compute/exec/hash_join.cc b/cpp/src/arrow/compute/exec/hash_join.cc index 3d8be95b22f67..7d646e6ccf92f 100644 --- a/cpp/src/arrow/compute/exec/hash_join.cc +++ b/cpp/src/arrow/compute/exec/hash_join.cc @@ -348,8 +348,9 @@ class HashJoinBasicImpl : public HashJoinImpl { } ARROW_DCHECK(mask.array()->offset == 0); ARROW_DCHECK(mask.array()->length == static_cast(match_left.size())); - const uint8_t* bv = mask.array()->buffers[1]->data(); - int num_rows = match_left.size(); + const uint8_t* nulls = mask.array()->buffers[0]->data(); + const uint8_t* comparisons = mask.array()->buffers[1]->data(); + size_t num_rows = match_left.size(); match.clear(); no_match.clear(); @@ -358,13 +359,18 @@ class HashJoinBasicImpl : public HashJoinImpl { int32_t irow = 0; // index into match_left for (int32_t curr_left = 0; static_cast(curr_left) < num_probed_rows; curr_left++) { - int32_t advance_to = irow < num_rows ? match_left[irow] : num_probed_rows; + int32_t advance_to = + static_cast(irow) < num_rows ? match_left[irow] : num_probed_rows; while (curr_left < advance_to) { no_match.push_back(curr_left++); } bool passed = false; - for (; irow < num_rows && match_left[irow] == curr_left; irow++) { - if (BitUtil::GetBit(bv, irow)) { + for (; static_cast(irow) < num_rows && match_left[irow] == curr_left; + irow++) { + bool is_null = !BitUtil::GetBit(nulls, irow); + bool is_cmp_true = BitUtil::GetBit(comparisons, irow); + // We treat a null comparison result as false, like in SQL + if (!is_null && is_cmp_true) { match_left[match_idx] = match_left[irow]; match_right[match_idx] = match_right[irow]; match_idx++; diff --git a/cpp/src/arrow/compute/exec/hash_join_node_test.cc b/cpp/src/arrow/compute/exec/hash_join_node_test.cc index b4b1aecca4da9..ac65025a42afc 100644 --- a/cpp/src/arrow/compute/exec/hash_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/hash_join_node_test.cc @@ -382,17 +382,23 @@ struct RandomDataTypeVector { data_types.push_back(RandomDataType::Random(rng, constraints)); } + std::string DataTypeToString(size_t i) { + std::stringstream ss; + + if (!data_types[i].is_fixed_length) { + ss << "str[" << data_types[i].min_string_length << ".." + << data_types[i].max_string_length << "]"; + } else { + ss << "int[" << data_types[i].fixed_length << "]"; + } + return ss.str(); + } + void Print() { for (size_t i = 0; i < data_types.size(); ++i) { - if (!data_types[i].is_fixed_length) { - std::cout << "str[" << data_types[i].min_string_length << ".." - << data_types[i].max_string_length << "]"; - SCOPED_TRACE("str[" + std::to_string(data_types[i].min_string_length) + ".." + - std::to_string(data_types[i].max_string_length) + "]"); - } else { - std::cout << "int[" << data_types[i].fixed_length << "]"; - SCOPED_TRACE("int[" + std::to_string(data_types[i].fixed_length) + "]"); - } + std::string stringified = DataTypeToString(i); + std::cout << stringified; + SCOPED_TRACE(stringified); } std::cout << std::endl; } @@ -951,7 +957,6 @@ void HashJoinWithExecPlan(Random64Bit& rng, bool parallel, TEST(HashJoin, Random) { Random64Bit rng(42); - int num_tests = 100; for (int test_id = 0; test_id < num_tests; ++test_id) { bool parallel = (rng.from_range(0, 1) == 1); @@ -1061,15 +1066,16 @@ TEST(HashJoin, Random) { // Print test case parameters // print num_rows, batch_size, join_type, join_cmp - std::cout << join_type_name << " " << key_cmp_str << " "; + std::cout << "Trial " << test_id << ":\n"; + std::cout << " " << join_type_name << " " << key_cmp_str << " "; key_types.Print(); - std::cout << " payload_l: "; + std::cout << " payload_l: "; payload_types[0].Print(); - std::cout << " payload_r: "; + std::cout << " payload_r: "; payload_types[1].Print(); - std::cout << " num_rows_l = " << num_rows_l << " num_rows_r = " << num_rows_r - << " batch size = " << batch_size - << " parallel = " << (parallel ? "true" : "false"); + std::cout << " num_rows_l = " << num_rows_l << " num_rows_r = " << num_rows_r + << " batch size = " << batch_size + << " parallel = " << (parallel ? "true" : "false"); std::cout << std::endl; // Run reference join implementation std::vector null_in_key_vectors[2]; @@ -1094,7 +1100,7 @@ TEST(HashJoin, Random) { FieldRef left = key_fields[0][i]; FieldRef right = key_fields[1][i]; - if (key_cmp[i] == JoinKeyCmp::IS && left.IsName() && right.IsName()) { + if (key_cmp[i] == JoinKeyCmp::EQ && left.IsName() && right.IsName()) { key_fields[0].erase(key_fields[0].begin() + i); key_fields[1].erase(key_fields[1].begin() + i); key_cmp.erase(key_cmp.begin() + i); @@ -1102,14 +1108,17 @@ TEST(HashJoin, Random) { Expression right_expr(field_ref(right)); filter = equal(left_expr, right_expr); + std::cout << " Filter comparing on type " << key_types.DataTypeToString(i) + << '\n'; break; } } } + if (!filter.IsEmpty()) { - std::cout << " Filter: " << filter.ToString() << "\n"; + std::cout << " Filter: " << filter.ToString() << "\n"; } else { - std::cout << " Filter: \n"; + std::cout << " Filter: \n"; } // Run tested join implementation