Support join filter for SortMergeJoin
#9080
Conversation
7ab4393 to d9e40af
Thanks @viirya -- the code is looking pretty good to me. I think this PR may need to handle OUTER joins as well, but maybe it already does. Adding some test coverage could probably tell us one way or the other
# under the License.

##########
## Sort Merge Join Tests
I think we should also add tests for LEFT/RIGHT OUTER joins where the filter needs to be applied to the non-preserved side (aka applied during the join)
Yes, all supported Join types are supported. Let me add more test coverage.
PTAL @metesynnada
I believe there should be more unit tests for numeric values as well. Other than that, the PR looks fine.
let mut streamed_columns = self
    .streamed_schema
    .fields()
    .iter()
    .map(|f| new_null_array(f.data_type(), buffered_indices.len()))
    .collect::<Vec<_>>();

let filter_columns =
Inside `get_filter_column`, the filter is checked to see if it is `Some`. If not, the result will be an empty vector. Instead of doing that, you can move the filter columns calculation under `let output_batch = if let Some(f) = &self.filter {` and make `get_filter_column` expect `joinfilter_filter: &JoinFilter`.
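The suggested refactor could be sketched roughly as follows. This is a toy model, not DataFusion's actual code: `JoinFilter`, `ColumnIndex`, and the use of plain `i64` columns are stand-in simplifications to show the shape of the signature change.

```rust
// Toy stand-ins for the real DataFusion types (assumption: simplified).
struct ColumnIndex {
    index: usize,
}
struct JoinFilter {
    column_indices: Vec<ColumnIndex>,
}

// After the refactor, the helper takes `&JoinFilter` directly instead of an
// `Option`, so it never needs a special case returning an empty vector.
fn get_filter_column(join_filter: &JoinFilter, columns: &[i64]) -> Vec<i64> {
    join_filter
        .column_indices
        .iter()
        .map(|ci| columns[ci.index])
        .collect()
}

fn main() {
    let filter = JoinFilter {
        column_indices: vec![ColumnIndex { index: 1 }],
    };
    let columns = vec![10, 20, 30];
    // The call site moves under `if let Some(f) = &self.filter { ... }`, so
    // the helper only runs when a filter actually exists.
    let filtered = get_filter_column(&filter, &columns);
    assert_eq!(filtered, vec![20]);
    println!("{:?}", filtered);
}
```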
It is because `buffered_columns` is consumed by `streamed_columns.extend(buffered_columns);`. So either I clone `buffered_columns` before it, or call `get_filter_column` before it, as I currently do.
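The ordering constraint here comes from Rust move semantics. A minimal standalone sketch (using plain `Vec<i32>` in place of Arrow arrays; names mirror the PR but the types are simplified) of why the filter columns must be gathered before the `extend`:

```rust
fn main() {
    let mut streamed_columns: Vec<i32> = vec![1, 2];
    let buffered_columns: Vec<i32> = vec![3, 4];

    // Gather filter columns BEFORE the extend below, while `buffered_columns`
    // is still accessible. (In the PR this is the `get_filter_column` call.)
    let filter_columns: Vec<i32> = buffered_columns.iter().copied().collect();

    // `Vec::extend` with a by-value `Vec` moves `buffered_columns`; after this
    // line it can no longer be used unless it was cloned earlier.
    streamed_columns.extend(buffered_columns);

    assert_eq!(streamed_columns, vec![1, 2, 3, 4]);
    assert_eq!(filter_columns, vec![3, 4]);
    println!("{:?}", streamed_columns);
}
```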
Thanks for the review. I found this has some issues with OUTER joins; going to revise it and add more tests.
Marking as a draft to make it clear this PR is not waiting on review
33 c 3 NULL NULL NULL
44 d 4 44 x 3
NULL NULL NULL 11 z 3
NULL NULL NULL 55 w 3
This is added to compare with sort_merge_join.slt results for full join.
Thank you @alamb. I fixed the issue for outer joins and marked this ready for review again.
Thanks @viirya -- I plan to review this later today
Thank you very much @viirya -- I reviewed the code carefully and it looks like a nice improvement to me and has good test coverage. I had some code organization / comment suggestions, but nothing that I think would prevent merging
cc @korowa, @liukun4515 and @metesynnada in case you have some additional thoughts to share, as I think you may be familiar with this code.
RecordBatch::try_new(self.schema.clone(), columns.clone())?;

// Apply join filter if any
if !filter_columns.is_empty() {
I don't understand why there is the check for filter columns in addition to whether `self.filter` is `Some`. I expected the check to simply be whether `self.filter` is `Some` (and the `else` case is the same for both below).
If the filter has no columns, it seems like the `else` clause does the same thing in both cases. Thus, I wonder if we could remove the check for `filter_columns` entirely 🤔
Because if this joined batch is between a streamed batch and null (i.e., outer joins), we don't need to handle the join filter (although the join filter is `Some`).
I will add a short comment here.
Ah, I see -- this would be for batches that don't have any matches from equality predicates anyways - which makes sense
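A condensed model of the control flow being discussed (simplified stand-in code, not DataFusion's actual implementation): a batch joined against nulls for outer-join padding carries no filter columns, so the join filter is skipped even when one is configured.

```rust
// Simplified sketch: `filter` models `self.filter`, `filter_columns` models
// the columns gathered by `get_filter_column`. A null-padded outer-join batch
// produces an empty `filter_columns`, which is why both checks exist.
fn rows_after_filter(
    filter: Option<fn(i32) -> bool>,
    filter_columns: &[i32],
    num_rows: usize,
) -> usize {
    match (filter, filter_columns.is_empty()) {
        // Filter exists AND this batch has filter columns: evaluate it.
        (Some(f), false) => filter_columns.iter().filter(|&&v| f(v)).count(),
        // Null-padded batch (outer join) or no filter configured: pass through.
        _ => num_rows,
    }
}

fn main() {
    let pred: fn(i32) -> bool = |v| v > 10;
    // Normal joined batch: the filter applies and drops one row.
    assert_eq!(rows_after_filter(Some(pred), &[5, 20, 30], 3), 2);
    // Streamed-vs-null batch: no filter columns, filter skipped.
    assert_eq!(rows_after_filter(Some(pred), &[], 4), 4);
    println!("ok");
}
```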
self.join_type,
JoinType::Left | JoinType::Right | JoinType::Full
) {
    // The reverse of the selection mask, which is for null joined rows
Does 'null joined rows' mean 'rows that passed neither the equijoin predicates nor the filter'? If so, I would find a term like `non_matching_rows` easier to understand. But that is a personal preference.
The rows that reach here have all passed the equijoin predicates already (their `buffered_batch_idx` is `Some`). "null joined rows" here means the rows that did not pass the join filter, which we are going to join (on the left or right side) with null. Let me add a few words to make it clear.
buffered_columns.extend(streamed_columns);
buffered_columns
} else {
I missed the fact that this handles left and full (not just left)
} else {
}
// Left join or full outer join
else {
@@ -1142,12 +1294,49 @@ impl SMJStream {
let record_batch = concat_batches(&self.schema, &self.output_record_batches)?;
self.join_metrics.output_batches.add(1);
self.join_metrics.output_rows.add(record_batch.num_rows());
self.output_size -= record_batch.num_rows();
// If join filter exists, `self.output_size` is not accurate as we don't know the exact
Is the idea here that `output_size` is tracking the number of rows remaining to output? If so, it seems like the filter could only decrease the number of output rows (never increase it).
However, I can see how the SMJ code could overshoot for LEFT/RIGHT/FULL joins, so maybe this fix was needed because now there is more test coverage of SMJ 🤔
The logic of `output_size` assumes that each row put into the buffer will produce exactly one output row. It is increased when we put rows into the buffer and decreased after we actually output batches. So it is used to track the number of rows in the buffers. We compare it with `self.batch_size` (the target output batch size) and decide to output batches from the buffers once it is reached.
For joins with a join filter, the assumption behind `output_size` is broken: one row put into the buffer may produce more than one output row. For example, if one joined row under a full join doesn't pass the join filter, it will produce two output rows, i.e., the streamed row joined with a null row, and a null row joined with the buffered row.
So the actual output row count `record_batch.num_rows()` may be larger than `self.output_size`, and `self.output_size -= record_batch.num_rows()` would overflow.
For such a case, we can simply reset `output_size`.
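The accounting hazard can be shown in a minimal sketch (hypothetical numbers; one of several ways to avoid the underflow, not necessarily the exact fix in the PR):

```rust
fn main() {
    // `output_size` counts rows expected from the buffers, assuming one
    // output row per buffered row. With a join filter, a single buffered row
    // can expand to several output rows (e.g. under a FULL join), so the
    // emitted batch can exceed the counter.
    let mut output_size: usize = 3;
    let emitted_rows: usize = 5; // larger than output_size when rows expand

    // A plain `output_size -= emitted_rows` would underflow a `usize` here
    // and panic in debug builds; clamping to zero ("clean up output_size")
    // avoids that.
    output_size = output_size.saturating_sub(emitted_rows);

    assert_eq!(output_size, 0);
    println!("{}", output_size);
}
```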
I got it -- thank you for the explanation. I didn't understand the assumptions / invariants of output_size. Maybe we can clarify this somehow in comments (I left one suggestion, but maybe it is not correct)
Thank you @alamb. The suggestion looks good to me.
Co-authored-by: Andrew Lamb <[email protected]>
Thanks @alamb @metesynnada for review.
Which issue does this PR close?
Closes #.
Rationale for this change
DataFusion `SortMergeJoin` doesn't support join filters for now. Any logical join operator with a join filter can only be planned as `HashJoinExec`, which supports join filters. Spark's `SortMergeJoin` supports join filters. Without join filter support, we cannot translate the Spark `SortMergeJoin` operator to DataFusion.
What changes are included in this PR?
This patch adds join filter support to DataFusion `SortMergeJoin`.
Are these changes tested?
Are there any user-facing changes?