diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index a7bb31b6e109..5c7c891f92d2 100644 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ b/datafusion/physical-expr/src/window/nth_value.rs @@ -42,6 +42,7 @@ pub struct NthValue { /// Output data type data_type: DataType, kind: NthValueKind, + ignore_nulls: bool, } impl NthValue { @@ -50,12 +51,14 @@ impl NthValue { name: impl Into, expr: Arc, data_type: DataType, + ignore_nulls: bool, ) -> Self { Self { name: name.into(), expr, data_type, kind: NthValueKind::First, + ignore_nulls, } } @@ -64,12 +67,14 @@ impl NthValue { name: impl Into, expr: Arc, data_type: DataType, + ignore_nulls: bool, ) -> Self { Self { name: name.into(), expr, data_type, kind: NthValueKind::Last, + ignore_nulls, } } @@ -79,7 +84,11 @@ impl NthValue { expr: Arc, data_type: DataType, n: u32, + ignore_nulls: bool, ) -> Result { + if ignore_nulls { + return exec_err!("NTH_VALUE ignore_nulls is not supported yet"); + } match n { 0 => exec_err!("NTH_VALUE expects n to be non-zero"), _ => Ok(Self { @@ -87,6 +96,7 @@ impl NthValue { expr, data_type, kind: NthValueKind::Nth(n as i64), + ignore_nulls, }), } } @@ -122,7 +132,10 @@ impl BuiltInWindowFunctionExpr for NthValue { finalized_result: None, kind: self.kind, }; - Ok(Box::new(NthValueEvaluator { state })) + Ok(Box::new(NthValueEvaluator { + state, + ignore_nulls: self.ignore_nulls, + })) } fn reverse_expr(&self) -> Option> { @@ -136,6 +149,7 @@ impl BuiltInWindowFunctionExpr for NthValue { expr: self.expr.clone(), data_type: self.data_type.clone(), kind: reversed_kind, + ignore_nulls: self.ignore_nulls, })) } } @@ -144,6 +158,7 @@ impl BuiltInWindowFunctionExpr for NthValue { #[derive(Debug)] pub(crate) struct NthValueEvaluator { state: NthValueState, + ignore_nulls: bool, } impl PartitionEvaluator for NthValueEvaluator { @@ -184,7 +199,8 @@ impl PartitionEvaluator for NthValueEvaluator { } } }; - if is_prunable { + // Do not memoize results when nulls are ignored. + if is_prunable && !self.ignore_nulls { if self.state.finalized_result.is_none() && !is_reverse_direction { let result = ScalarValue::try_from_array(out, size - 1)?; self.state.finalized_result = Some(result); @@ -210,9 +226,39 @@ impl PartitionEvaluator for NthValueEvaluator { // We produce None if the window is empty. return ScalarValue::try_from(arr.data_type()); } + + // Extract valid indices if ignoring nulls. + let (slice, valid_indices) = if self.ignore_nulls { + let slice = arr.slice(range.start, n_range); + let valid_indices = + slice.nulls().unwrap().valid_indices().collect::>(); + if valid_indices.is_empty() { + return ScalarValue::try_from(arr.data_type()); + } + (Some(slice), Some(valid_indices)) + } else { + (None, None) + }; match self.state.kind { - NthValueKind::First => ScalarValue::try_from_array(arr, range.start), - NthValueKind::Last => ScalarValue::try_from_array(arr, range.end - 1), + NthValueKind::First => { + if let Some(slice) = &slice { + let valid_indices = valid_indices.unwrap(); + ScalarValue::try_from_array(slice, valid_indices[0]) + } else { + ScalarValue::try_from_array(arr, range.start) + } + } + NthValueKind::Last => { + if let Some(slice) = &slice { + let valid_indices = valid_indices.unwrap(); + ScalarValue::try_from_array( + slice, + valid_indices[valid_indices.len() - 1], + ) + } else { + ScalarValue::try_from_array(arr, range.end - 1) + } + } NthValueKind::Nth(n) => { match n.cmp(&0) { Ordering::Greater => { @@ -295,6 +341,7 @@ mod tests { "first_value".to_owned(), Arc::new(Column::new("arr", 0)), DataType::Int32, + false, ); test_i32_result(first_value, Int32Array::from(vec![1; 8]))?; Ok(()) @@ -306,6 +353,7 @@ mod tests { "last_value".to_owned(), Arc::new(Column::new("arr", 0)), DataType::Int32, + false, ); test_i32_result( last_value, @@ -330,6 +378,7 @@ mod tests { Arc::new(Column::new("arr", 0)), DataType::Int32, 1, + false, )?; test_i32_result(nth_value, Int32Array::from(vec![1; 8]))?; Ok(()) @@ -342,6 +391,7 @@ mod tests { Arc::new(Column::new("arr", 0)), DataType::Int32, 2, + false, )?; test_i32_result( nth_value, diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index 4cba571054de..0349f8f1eeec 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -1179,15 +1179,19 @@ mod tests { .map(|e| Arc::new(e) as Arc)?; let col_a = col("a", &schema)?; let nth_value_func1 = - NthValue::nth("nth_value(-1)", col_a.clone(), DataType::Int32, 1)? + NthValue::nth("nth_value(-1)", col_a.clone(), DataType::Int32, 1, false)? .reverse_expr() .unwrap(); let nth_value_func2 = - NthValue::nth("nth_value(-2)", col_a.clone(), DataType::Int32, 2)? + NthValue::nth("nth_value(-2)", col_a.clone(), DataType::Int32, 2, false)? .reverse_expr() .unwrap(); - let last_value_func = - Arc::new(NthValue::last("last", col_a.clone(), DataType::Int32)) as _; + let last_value_func = Arc::new(NthValue::last( + "last", + col_a.clone(), + DataType::Int32, + false, + )) as _; let window_exprs = vec![ // LAST_VALUE(a) Arc::new(BuiltInWindowExpr::new( diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index f91b525d6090..6712bc855ffd 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -250,15 +250,21 @@ fn create_built_in_window_expr( .try_into() .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?; let n: u32 = n as u32; - Arc::new(NthValue::nth(name, arg, data_type.clone(), n)?) + Arc::new(NthValue::nth( + name, + arg, + data_type.clone(), + n, + ignore_nulls, + )?) } BuiltInWindowFunction::FirstValue => { let arg = args[0].clone(); - Arc::new(NthValue::first(name, arg, data_type.clone())) + Arc::new(NthValue::first(name, arg, data_type.clone(), ignore_nulls)) } BuiltInWindowFunction::LastValue => { let arg = args[0].clone(); - Arc::new(NthValue::last(name, arg, data_type.clone())) + Arc::new(NthValue::last(name, arg, data_type.clone(), ignore_nulls)) } }) } diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index a3c0b3eccd3c..004261eff595 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -271,6 +271,7 @@ fn roundtrip_window() -> Result<()> { "FIRST_VALUE(a) PARTITION BY [b] ORDER BY [a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", col("a", &schema)?, DataType::Int64, + false, )), &[col("b", &schema)?], &[PhysicalSortExpr { diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index cce67d898d37..39c105a4dcce 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -4307,3 +4307,275 @@ select lag(a) over (order by a ASC NULLS FIRST) as x1 NULL NULL NULL + +# Test for ignore nulls in FIRST_VALUE +statement ok +CREATE TABLE t AS VALUES (null::bigint), (3), (4); + +query I +SELECT FIRST_VALUE(column1) OVER() FROM t; +---- +NULL +NULL +NULL + +query I +SELECT FIRST_VALUE(column1) RESPECT NULLS OVER() FROM t; +---- +NULL +NULL +NULL + +query I +SELECT FIRST_VALUE(column1) IGNORE NULLS OVER() FROM t; +---- +3 +3 +3 + +statement ok +DROP TABLE t; + +# Test for ignore nulls with ORDER BY in FIRST_VALUE +statement ok +CREATE TABLE t AS VALUES (3, 4), (4, 3), (null::bigint, 1), (null::bigint, 2), (5, 5), (6, 6); + +query II +SELECT column1, column2 FROM t ORDER BY column2; +---- +NULL 1 +NULL 2 +4 3 +3 4 +5 5 +6 6 + +query II +SELECT FIRST_VALUE(column1) OVER(ORDER BY column2), column2 FROM t; +---- +NULL 1 +NULL 2 +NULL 3 +NULL 4 +NULL 5 +NULL 6 + +query II +SELECT FIRST_VALUE(column1) RESPECT NULLS OVER(ORDER BY column2), column2 FROM t; +---- +NULL 1 +NULL 2 +NULL 3 +NULL 4 +NULL 5 +NULL 6 + +query II +SELECT FIRST_VALUE(column1) IGNORE NULLS OVER(ORDER BY column2), column2 FROM t; +---- +NULL 1 +NULL 2 +4 3 +4 4 +4 5 +4 6 + +query II +SELECT FIRST_VALUE(column1)OVER(ORDER BY column2 RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING), column2 FROM t; +---- +NULL 1 +NULL 2 +NULL 3 +4 4 +3 5 +5 6 + +query II +SELECT FIRST_VALUE(column1) IGNORE NULLS OVER(ORDER BY column2 RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING), column2 FROM t; +---- +NULL 1 +4 2 +4 3 +4 4 +3 5 +5 6 + +statement ok +DROP TABLE t; + +# Test for ignore nulls with ORDER BY in FIRST_VALUE with all NULL values +statement ok +CREATE TABLE t AS VALUES (null::bigint, 4), (null::bigint, 3), (null::bigint, 1), (null::bigint, 2); + +query II +SELECT FIRST_VALUE(column1) OVER(ORDER BY column2), column2 FROM t; +---- +NULL 1 +NULL 2 +NULL 3 +NULL 4 + +query II +SELECT FIRST_VALUE(column1) RESPECT NULLS OVER(ORDER BY column2), column2 FROM t; +---- +NULL 1 +NULL 2 +NULL 3 +NULL 4 + +query II +SELECT FIRST_VALUE(column1) IGNORE NULLS OVER(ORDER BY column2), column2 FROM t; +---- +NULL 1 +NULL 2 +NULL 3 +NULL 4 + +statement ok +DROP TABLE t; + +# Test for ignore nulls in LAST_VALUE +statement ok +CREATE TABLE t AS VALUES (1), (3), (null::bigint); + +query I +SELECT LAST_VALUE(column1) OVER() FROM t; +---- +NULL +NULL +NULL + +query I +SELECT LAST_VALUE(column1) RESPECT NULLS OVER() FROM t; +---- +NULL +NULL +NULL + +query I +SELECT LAST_VALUE(column1) IGNORE NULLS OVER() FROM t; +---- +3 +3 +3 + +statement ok +DROP TABLE t; + +# Test for ignore nulls with ORDER BY in LAST_VALUE +statement ok +CREATE TABLE t AS VALUES (3, 4), (4, 3), (null::bigint, 1), (null::bigint, 2), (5, 5), (6, 6); + +query II +SELECT column1, column2 FROM t ORDER BY column2 DESC NULLS LAST; +---- +6 6 +5 5 +3 4 +4 3 +NULL 2 +NULL 1 + +query II +SELECT LAST_VALUE(column1) OVER(ORDER BY column2 DESC NULLS LAST), column2 FROM t; +---- +6 6 +5 5 +3 4 +4 3 +NULL 2 +NULL 1 + +query II +SELECT LAST_VALUE(column1) IGNORE NULLS OVER(ORDER BY column2 DESC NULLS LAST), column2 FROM t; +---- +6 6 +5 5 +3 4 +4 3 +4 2 +4 1 + +query II +SELECT LAST_VALUE(column1) OVER(ORDER BY column2 DESC NULLS LAST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), column2 FROM t; +---- +NULL 6 +NULL 5 +NULL 4 +NULL 3 +NULL 2 +NULL 1 + +query II +SELECT LAST_VALUE(column1) RESPECT NULLS OVER(ORDER BY column2 DESC NULLS LAST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), column2 FROM t; +---- +NULL 6 +NULL 5 +NULL 4 +NULL 3 +NULL 2 +NULL 1 + +query II +SELECT LAST_VALUE(column1) IGNORE NULLS OVER(ORDER BY column2 DESC NULLS LAST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), column2 FROM t; +---- +4 6 +4 5 +4 4 +4 3 +4 2 +4 1 + +query II +SELECT LAST_VALUE(column1) OVER(ORDER BY column2 DESC NULLS LAST RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING), column2 FROM t; +---- +5 6 +3 5 +4 4 +NULL 3 +NULL 2 +NULL 1 + +query II +SELECT LAST_VALUE(column1) IGNORE NULLS OVER(ORDER BY column2 DESC NULLS LAST RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING), column2 FROM t; +---- +5 6 +3 5 +4 4 +4 3 +4 2 +NULL 1 + +statement ok +DROP TABLE t; + +# Test for ignore nulls with ORDER BY in LAST_VALUE with all NULLs +statement ok +CREATE TABLE t AS VALUES (null::bigint, 4), (null::bigint, 3), (null::bigint, 1), (null::bigint, 2); + +query II +SELECT LAST_VALUE(column1) OVER(ORDER BY column2 DESC NULLS LAST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), column2 FROM t; +---- +NULL 4 +NULL 3 +NULL 2 +NULL 1 + +query II +SELECT LAST_VALUE(column1) RESPECT NULLS OVER(ORDER BY column2 DESC NULLS LAST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), column2 FROM t; +---- +NULL 4 +NULL 3 +NULL 2 +NULL 1 + +query II +SELECT LAST_VALUE(column1) IGNORE NULLS OVER(ORDER BY column2 DESC NULLS LAST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), column2 FROM t; +---- +NULL 4 +NULL 3 +NULL 2 +NULL 1 + +statement ok +DROP TABLE t;