Skip to content

Commit

Permalink
Infer the count of maximum distinct values from min/max (#3837)
Browse files Browse the repository at this point in the history
* Infer the count of maximum distinct values from min/max

* Even if the delta is 0, ensure that the distinct count is 1 (when min=max)
  • Loading branch information
isidentical authored Oct 15, 2022
1 parent e02376d commit fe0000e
Showing 1 changed file with 196 additions and 31 deletions.
227 changes: 196 additions & 31 deletions datafusion/core/src/physical_plan/join_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use crate::logical_expr::JoinType;
use crate::physical_plan::expressions::Column;
use arrow::datatypes::{Field, Schema};
use arrow::error::ArrowError;
use datafusion_common::ScalarValue;
use datafusion_physical_expr::PhysicalExpr;
use futures::future::{BoxFuture, Shared};
use futures::{ready, FutureExt};
Expand Down Expand Up @@ -423,7 +424,9 @@ fn estimate_inner_join_cardinality(
return None;
}

let max_distinct = max(left_stat.distinct_count, right_stat.distinct_count);
let left_max_distinct = max_distinct_count(left_num_rows, left_stat.clone());
let right_max_distinct = max_distinct_count(right_num_rows, right_stat.clone());
let max_distinct = max(left_max_distinct, right_max_distinct);
if max_distinct > join_selectivity {
// Seems like there are a few implementations of this algorithm that implement
// exponential decay for the selectivity (like Hive's Optiq Optimizer). Needs
Expand All @@ -447,6 +450,50 @@ fn estimate_inner_join_cardinality(
}
}

/// Estimate the maximum number of distinct values that can be present in the
/// given column, based on its statistics.
///
/// If `distinct_count` is available, it is returned directly. Otherwise, if
/// both `min_value` and `max_value` are present and describe an integer range,
/// the size of that range (capped by the number of non-null rows) is used as a
/// fallback. In every other case `None` is returned.
fn max_distinct_count(num_rows: usize, stats: ColumnStatistics) -> Option<usize> {
    match (stats.distinct_count, stats.max_value, stats.min_value) {
        (Some(distinct_count), _, _) => Some(distinct_count),
        (_, Some(max), Some(min)) => {
            // Note that float support is intentionally omitted here, since the
            // computation of a range between two float values is not trivial
            // and the result would be highly inaccurate.
            let numeric_range = get_int_range(min, max)?;

            // The number can never be greater than the number of rows we have
            // (minus the nulls, since they don't count as distinct values).
            // Statistics may be inexact, so a null count larger than the row
            // count must clamp to zero instead of underflowing.
            let ceiling = num_rows.saturating_sub(stats.null_count.unwrap_or(0));
            Some(numeric_range.min(ceiling))
        }
        _ => None,
    }
}

/// Return the number of integer values in the inclusive range `[min, max]`.
///
/// Returns `None` when `max - min` cannot be computed, when the difference is
/// not a non-null integer scalar, or when a signed difference is negative
/// (i.e. `min > max`). Results too large for `usize` saturate at `usize::MAX`.
fn get_int_range(min: ScalarValue, max: ScalarValue) -> Option<usize> {
    // Convert a non-negative delta to `usize`, saturating instead of silently
    // truncating on targets where `usize` is narrower than 64 bits.
    fn clamp_to_usize(delta: u64) -> usize {
        delta.min(usize::MAX as u64) as usize
    }

    // NOTE(review): whether `sub` reports an error or wraps when a signed
    // difference overflows its type is not visible here; confirm upstream.
    let delta = &max.sub(&min).ok()?;
    let open_ended_range = match delta {
        ScalarValue::Int8(Some(delta)) if *delta >= 0 => clamp_to_usize(*delta as u64),
        ScalarValue::Int16(Some(delta)) if *delta >= 0 => clamp_to_usize(*delta as u64),
        ScalarValue::Int32(Some(delta)) if *delta >= 0 => clamp_to_usize(*delta as u64),
        ScalarValue::Int64(Some(delta)) if *delta >= 0 => clamp_to_usize(*delta as u64),
        ScalarValue::UInt8(Some(delta)) => clamp_to_usize(u64::from(*delta)),
        ScalarValue::UInt16(Some(delta)) => clamp_to_usize(u64::from(*delta)),
        ScalarValue::UInt32(Some(delta)) => clamp_to_usize(u64::from(*delta)),
        ScalarValue::UInt64(Some(delta)) => clamp_to_usize(*delta),
        _ => return None,
    };

    // The delta (directly) is not the real range, since it does not include
    // the first term.
    // E.g. (min=2, max=4) -> (4 - 2) -> 2, but the actual result should be 3 (2, 3, 4).
    // Saturate so that a delta of `usize::MAX` does not overflow on the `+ 1`.
    Some(open_ended_range.saturating_add(1))
}

enum OnceFutState<T> {
Pending(OnceFutPending<T>),
Ready(Arc<Result<T>>),
Expand Down Expand Up @@ -626,19 +673,19 @@ mod tests {
}

fn create_column_stats(
min: Option<u64>,
max: Option<u64>,
min: Option<i64>,
max: Option<i64>,
distinct_count: Option<usize>,
) -> ColumnStatistics {
ColumnStatistics {
distinct_count,
min_value: min.map(|size| ScalarValue::UInt64(Some(size))),
max_value: max.map(|size| ScalarValue::UInt64(Some(size))),
min_value: min.map(|size| ScalarValue::Int64(Some(size))),
max_value: max.map(|size| ScalarValue::Int64(Some(size))),
..Default::default()
}
}

type PartialStats = (usize, u64, u64, Option<usize>);
type PartialStats = (usize, Option<i64>, Option<i64>, Option<usize>);

// This is mainly for validating all the edge cases of the estimation, but
// more advanced (and real world test cases) are below where we need some control
Expand All @@ -650,40 +697,135 @@ mod tests {
// | left(rows, min, max, distinct), right(rows, min, max, distinct), expected |
// -----------------------------------------------------------------------------

// distinct(left) is None OR distinct(right) is None
// Cardinality computation
// =======================
//
// distinct(left) == NaN, distinct(right) == NaN
(
(10, Some(1), Some(10), None),
(10, Some(1), Some(10), None),
Some(10),
),
// range(left) > range(right)
(
(10, Some(6), Some(10), None),
(10, Some(8), Some(10), None),
Some(20),
),
// range(right) > range(left)
(
(10, Some(8), Some(10), None),
(10, Some(6), Some(10), None),
Some(20),
),
// range(left) > len(left), range(right) > len(right)
(
(10, Some(1), Some(15), None),
(20, Some(1), Some(40), None),
Some(10),
),
// When we have distinct count.
(
(10, Some(1), Some(10), Some(10)),
(10, Some(1), Some(10), Some(10)),
Some(10),
),
// distinct(left) > distinct(right)
(
(10, Some(1), Some(10), Some(5)),
(10, Some(1), Some(10), Some(2)),
Some(20),
),
// distinct(right) > distinct(left)
(
(10, Some(1), Some(10), Some(2)),
(10, Some(1), Some(10), Some(5)),
Some(20),
),
// min(left) < 0 (range(left) > range(right))
(
(10, Some(-5), Some(5), None),
(10, Some(1), Some(5), None),
Some(10),
),
// min(right) < 0, max(right) < 0 (range(right) > range(left))
(
(10, Some(-25), Some(-20), None),
(10, Some(-25), Some(-15), None),
Some(10),
),
// range(left) < 0, range(right) >= 0
// (there isn't a case where both left and right ranges are negative
// so one of them is always going to work, this just proves negative
// ranges with bigger absolute values are not accidentally used).
(
(10, Some(10), Some(0), None),
(10, Some(0), Some(10), Some(5)),
Some(20), // It would have been ten if we have used abs(range(left))
),
// range(left) = 1, range(right) = 1
(
(10, Some(1), Some(1), None),
(10, Some(1), Some(1), None),
Some(100),
),
//
// len(left) = len(right), len(left) * len(right)
((10, 0, 10, None), (10, 0, 10, None), None),
// len(left) > len(right) OR len(left) < len(right), len(left) * len(right)
((10, 0, 10, None), (5, 0, 10, None), None),
((5, 0, 10, None), (10, 0, 10, None), None),
((10, 0, 10, None), (5, 0, 10, None), None),
((5, 0, 10, None), (10, 0, 10, None), None),
// min(left) > max(right) OR min(right) > max(left), None
((10, 0, 10, None), (10, 11, 20, None), None),
((10, 11, 20, None), (10, 0, 10, None), None),
((10, 5, 10, None), (10, 11, 3, None), None),
((10, 10, 5, None), (10, 3, 7, None), None),
// distinct(left) is not None AND distinct(right) is not None
// Edge cases
// ==========
//
// len(left) = len(right), len(left) * len(right) / max(distinct(left), distinct(right))
((10, 0, 10, Some(5)), (10, 0, 10, Some(5)), Some(20)),
((10, 0, 10, Some(10)), (10, 0, 10, Some(5)), Some(10)),
((10, 0, 10, Some(5)), (10, 0, 10, Some(10)), Some(10)),
// No column level stats.
((10, None, None, None), (10, None, None, None), None),
// No min or max (or both).
((10, None, None, Some(3)), (10, None, None, Some(3)), None),
(
(10, Some(2), None, Some(3)),
(10, None, Some(5), Some(3)),
None,
),
(
(10, None, Some(3), Some(3)),
(10, Some(1), None, Some(3)),
None,
),
((10, None, Some(3), None), (10, Some(1), None, None), None),
// Non overlapping min/max.
(
(10, Some(0), Some(10), None),
(10, Some(11), Some(20), None),
None,
),
(
(10, Some(11), Some(20), None),
(10, Some(0), Some(10), None),
None,
),
(
(10, Some(5), Some(10), Some(10)),
(10, Some(11), Some(3), Some(10)),
None,
),
(
(10, Some(10), Some(5), Some(10)),
(10, Some(3), Some(7), Some(10)),
None,
),
// distinct(left) = 0, distinct(right) = 0
(
(10, Some(1), Some(10), Some(0)),
(10, Some(1), Some(10), Some(0)),
None,
),
];

for (left_info, right_info, expected_cardinality) in cases {
let left_num_rows = left_info.0;
let left_col_stats = vec![create_column_stats(
Some(left_info.1),
Some(left_info.2),
left_info.3,
)];
let left_col_stats =
vec![create_column_stats(left_info.1, left_info.2, left_info.3)];

let right_num_rows = right_info.0;
let right_col_stats = vec![create_column_stats(
Some(right_info.1),
Some(right_info.2),
right_info.1,
right_info.2,
right_info.3,
)];

Expand Down Expand Up @@ -740,6 +882,29 @@ mod tests {
Ok(())
}

#[test]
fn test_inner_join_cardinality_decimal_range() -> Result<()> {
    // Builds single-column statistics that carry only Decimal128(14, 4)
    // min/max bounds and no distinct count.
    let decimal_stats = |min: i128, max: i128| {
        vec![ColumnStatistics {
            distinct_count: None,
            min_value: Some(ScalarValue::Decimal128(Some(min), 14, 4)),
            max_value: Some(ScalarValue::Decimal128(Some(max), 14, 4)),
            ..Default::default()
        }]
    };

    // The bounds overlap, yet no estimate is expected for decimal columns.
    let estimate = estimate_inner_join_cardinality(
        100,
        100,
        decimal_stats(32500, 35000),
        decimal_stats(33500, 34000),
    );
    assert_eq!(estimate, None);

    Ok(())
}

#[test]
fn test_join_cardinality() -> Result<()> {
// Left table (rows=1000)
Expand Down

0 comments on commit fe0000e

Please sign in to comment.