Skip to content

Commit

Permalink
Treat null comparison result as false
Browse files Browse the repository at this point in the history
  • Loading branch information
save-buffer committed Nov 4, 2021
1 parent 3ef49e2 commit b7d5a5d
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 24 deletions.
16 changes: 11 additions & 5 deletions cpp/src/arrow/compute/exec/hash_join.cc
Original file line number Diff line number Diff line change
Expand Up @@ -348,8 +348,9 @@ class HashJoinBasicImpl : public HashJoinImpl {
}
ARROW_DCHECK(mask.array()->offset == 0);
ARROW_DCHECK(mask.array()->length == static_cast<int64_t>(match_left.size()));
const uint8_t* bv = mask.array()->buffers[1]->data();
int num_rows = match_left.size();
const uint8_t* nulls = mask.array()->buffers[0]->data();
const uint8_t* comparisons = mask.array()->buffers[1]->data();
size_t num_rows = match_left.size();

match.clear();
no_match.clear();
Expand All @@ -358,13 +359,18 @@ class HashJoinBasicImpl : public HashJoinImpl {
int32_t irow = 0; // index into match_left
for (int32_t curr_left = 0; static_cast<size_t>(curr_left) < num_probed_rows;
curr_left++) {
int32_t advance_to = irow < num_rows ? match_left[irow] : num_probed_rows;
int32_t advance_to =
static_cast<size_t>(irow) < num_rows ? match_left[irow] : num_probed_rows;
while (curr_left < advance_to) {
no_match.push_back(curr_left++);
}
bool passed = false;
for (; irow < num_rows && match_left[irow] == curr_left; irow++) {
if (BitUtil::GetBit(bv, irow)) {
for (; static_cast<size_t>(irow) < num_rows && match_left[irow] == curr_left;
irow++) {
bool is_null = !BitUtil::GetBit(nulls, irow);
bool is_cmp_true = BitUtil::GetBit(comparisons, irow);
// We treat a null comparison result as false, like in SQL
if (!is_null && is_cmp_true) {
match_left[match_idx] = match_left[irow];
match_right[match_idx] = match_right[irow];
match_idx++;
Expand Down
47 changes: 28 additions & 19 deletions cpp/src/arrow/compute/exec/hash_join_node_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -382,17 +382,23 @@ struct RandomDataTypeVector {
data_types.push_back(RandomDataType::Random(rng, constraints));
}

// Builds a short textual description of the i-th entry of data_types:
// "int[<fixed_length>]" when the entry is marked fixed-length, otherwise
// "str[<min_string_length>..<max_string_length>]".
std::string DataTypeToString(size_t i) {
  const auto& entry = data_types[i];
  std::stringstream description;

  if (entry.is_fixed_length) {
    description << "int[" << entry.fixed_length << "]";
  } else {
    description << "str[" << entry.min_string_length << ".."
                << entry.max_string_length << "]";
  }
  return description.str();
}

// Writes a description of every entry in data_types to std::cout and then
// ends the line with std::endl.
// NOTE(review): as shown here the loop body emits each description twice --
// once through the inline if/else and once through DataTypeToString(i). This
// looks like the removed and added lines of a diff rendered together; confirm
// which version is the real body before relying on the output format.
void Print() {
for (size_t i = 0; i < data_types.size(); ++i) {
// Inline formatting: "str[min..max]" for entries that are not fixed-length,
// "int[len]" for fixed-length ones. The same text is passed to SCOPED_TRACE.
if (!data_types[i].is_fixed_length) {
std::cout << "str[" << data_types[i].min_string_length << ".."
<< data_types[i].max_string_length << "]";
SCOPED_TRACE("str[" + std::to_string(data_types[i].min_string_length) + ".." +
std::to_string(data_types[i].max_string_length) + "]");
} else {
std::cout << "int[" << data_types[i].fixed_length << "]";
SCOPED_TRACE("int[" + std::to_string(data_types[i].fixed_length) + "]");
}
// Helper-based formatting: print the string returned by DataTypeToString(i)
// and pass it to SCOPED_TRACE.
std::string stringified = DataTypeToString(i);
std::cout << stringified;
SCOPED_TRACE(stringified);
}
std::cout << std::endl;
}
Expand Down Expand Up @@ -951,7 +957,6 @@ void HashJoinWithExecPlan(Random64Bit& rng, bool parallel,

TEST(HashJoin, Random) {
Random64Bit rng(42);

int num_tests = 100;
for (int test_id = 0; test_id < num_tests; ++test_id) {
bool parallel = (rng.from_range(0, 1) == 1);
Expand Down Expand Up @@ -1061,15 +1066,16 @@ TEST(HashJoin, Random) {

// Print test case parameters
// print num_rows, batch_size, join_type, join_cmp
std::cout << join_type_name << " " << key_cmp_str << " ";
std::cout << "Trial " << test_id << ":\n";
std::cout << " " << join_type_name << " " << key_cmp_str << " ";
key_types.Print();
std::cout << " payload_l: ";
std::cout << " payload_l: ";
payload_types[0].Print();
std::cout << " payload_r: ";
std::cout << " payload_r: ";
payload_types[1].Print();
std::cout << " num_rows_l = " << num_rows_l << " num_rows_r = " << num_rows_r
<< " batch size = " << batch_size
<< " parallel = " << (parallel ? "true" : "false");
std::cout << " num_rows_l = " << num_rows_l << " num_rows_r = " << num_rows_r
<< " batch size = " << batch_size
<< " parallel = " << (parallel ? "true" : "false");
std::cout << std::endl;
// Run reference join implementation
std::vector<bool> null_in_key_vectors[2];
Expand All @@ -1094,22 +1100,25 @@ TEST(HashJoin, Random) {
FieldRef left = key_fields[0][i];
FieldRef right = key_fields[1][i];

if (key_cmp[i] == JoinKeyCmp::IS && left.IsName() && right.IsName()) {
if (key_cmp[i] == JoinKeyCmp::EQ && left.IsName() && right.IsName()) {
key_fields[0].erase(key_fields[0].begin() + i);
key_fields[1].erase(key_fields[1].begin() + i);
key_cmp.erase(key_cmp.begin() + i);
Expression left_expr(field_ref(left));
Expression right_expr(field_ref(right));

filter = equal(left_expr, right_expr);
std::cout << " Filter comparing on type " << key_types.DataTypeToString(i)
<< '\n';
break;
}
}
}

if (!filter.IsEmpty()) {
std::cout << " Filter: " << filter.ToString() << "\n";
std::cout << " Filter: " << filter.ToString() << "\n";
} else {
std::cout << " Filter: <empty>\n";
std::cout << " Filter: <empty>\n";
}

// Run tested join implementation
Expand Down

0 comments on commit b7d5a5d

Please sign in to comment.