Skip to content

Commit

Permalink
Support IGNORE NULLS for FIRST/LAST window function (#9470)
Browse files (browse the repository at this point in the history)
* Support IGNORE NULLS for FIRST/LAST window function

* fix error

* fix style error

* fix clippy error

* add tests for all NULL values

* address comments

* fix format

* address comments

* fix format

* Fix commented test case

* resolve conflicts

---------

Co-authored-by: Huaxin Gao <[email protected]>
Co-authored-by: Mustafa Akur <[email protected]>
  • Loading branch information
3 people authored Mar 12, 2024
1 parent ef9bc90 commit 92d046d
Show file tree
Hide file tree
Showing 5 changed files with 344 additions and 11 deletions.
58 changes: 54 additions & 4 deletions datafusion/physical-expr/src/window/nth_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ pub struct NthValue {
/// Output data type
data_type: DataType,
kind: NthValueKind,
ignore_nulls: bool,
}

impl NthValue {
Expand All @@ -50,12 +51,14 @@ impl NthValue {
name: impl Into<String>,
expr: Arc<dyn PhysicalExpr>,
data_type: DataType,
ignore_nulls: bool,
) -> Self {
Self {
name: name.into(),
expr,
data_type,
kind: NthValueKind::First,
ignore_nulls,
}
}

Expand All @@ -64,12 +67,14 @@ impl NthValue {
name: impl Into<String>,
expr: Arc<dyn PhysicalExpr>,
data_type: DataType,
ignore_nulls: bool,
) -> Self {
Self {
name: name.into(),
expr,
data_type,
kind: NthValueKind::Last,
ignore_nulls,
}
}

Expand All @@ -79,14 +84,19 @@ impl NthValue {
expr: Arc<dyn PhysicalExpr>,
data_type: DataType,
n: u32,
ignore_nulls: bool,
) -> Result<Self> {
if ignore_nulls {
return exec_err!("NTH_VALUE ignore_nulls is not supported yet");
}
match n {
0 => exec_err!("NTH_VALUE expects n to be non-zero"),
_ => Ok(Self {
name: name.into(),
expr,
data_type,
kind: NthValueKind::Nth(n as i64),
ignore_nulls,
}),
}
}
Expand Down Expand Up @@ -122,7 +132,10 @@ impl BuiltInWindowFunctionExpr for NthValue {
finalized_result: None,
kind: self.kind,
};
Ok(Box::new(NthValueEvaluator { state }))
Ok(Box::new(NthValueEvaluator {
state,
ignore_nulls: self.ignore_nulls,
}))
}

fn reverse_expr(&self) -> Option<Arc<dyn BuiltInWindowFunctionExpr>> {
Expand All @@ -136,6 +149,7 @@ impl BuiltInWindowFunctionExpr for NthValue {
expr: self.expr.clone(),
data_type: self.data_type.clone(),
kind: reversed_kind,
ignore_nulls: self.ignore_nulls,
}))
}
}
Expand All @@ -144,6 +158,7 @@ impl BuiltInWindowFunctionExpr for NthValue {
/// Partition evaluator produced by `NthValue` for the FIRST_VALUE / LAST_VALUE /
/// NTH_VALUE window functions.
#[derive(Debug)]
pub(crate) struct NthValueEvaluator {
// Evaluation state: the `kind` of value being computed (first / last / nth) and an
// optional memoized `finalized_result`.
state: NthValueState,
// When true, null values inside the window frame are skipped when selecting the
// first / last value, and result memoization is disabled.
ignore_nulls: bool,
}

impl PartitionEvaluator for NthValueEvaluator {
Expand Down Expand Up @@ -184,7 +199,8 @@ impl PartitionEvaluator for NthValueEvaluator {
}
}
};
if is_prunable {
// Do not memoize results when nulls are ignored.
if is_prunable && !self.ignore_nulls {
if self.state.finalized_result.is_none() && !is_reverse_direction {
let result = ScalarValue::try_from_array(out, size - 1)?;
self.state.finalized_result = Some(result);
Expand All @@ -210,9 +226,39 @@ impl PartitionEvaluator for NthValueEvaluator {
// We produce None if the window is empty.
return ScalarValue::try_from(arr.data_type());
}

// Extract valid (non-null) indices, relative to the frame slice, if ignoring nulls.
let (slice, valid_indices) = if self.ignore_nulls {
    let slice = arr.slice(range.start, n_range);
    // `nulls()` returns `None` when the array carries no validity buffer, i.e. every
    // row is valid. Unwrapping it would panic on columns that contain no nulls, so
    // treat a missing buffer as "all indices are valid".
    let valid_indices = match slice.nulls() {
        Some(nulls) => nulls.valid_indices().collect::<Vec<_>>(),
        None => (0..slice.len()).collect::<Vec<_>>(),
    };
    if valid_indices.is_empty() {
        // Every value in the frame is null: return the same typed null that is
        // produced for an empty window.
        return ScalarValue::try_from(arr.data_type());
    }
    (Some(slice), Some(valid_indices))
} else {
    (None, None)
};
match self.state.kind {
NthValueKind::First => ScalarValue::try_from_array(arr, range.start),
NthValueKind::Last => ScalarValue::try_from_array(arr, range.end - 1),
NthValueKind::First => {
if let Some(slice) = &slice {
let valid_indices = valid_indices.unwrap();
ScalarValue::try_from_array(slice, valid_indices[0])
} else {
ScalarValue::try_from_array(arr, range.start)
}
}
NthValueKind::Last => {
if let Some(slice) = &slice {
let valid_indices = valid_indices.unwrap();
ScalarValue::try_from_array(
slice,
valid_indices[valid_indices.len() - 1],
)
} else {
ScalarValue::try_from_array(arr, range.end - 1)
}
}
NthValueKind::Nth(n) => {
match n.cmp(&0) {
Ordering::Greater => {
Expand Down Expand Up @@ -295,6 +341,7 @@ mod tests {
"first_value".to_owned(),
Arc::new(Column::new("arr", 0)),
DataType::Int32,
false,
);
test_i32_result(first_value, Int32Array::from(vec![1; 8]))?;
Ok(())
Expand All @@ -306,6 +353,7 @@ mod tests {
"last_value".to_owned(),
Arc::new(Column::new("arr", 0)),
DataType::Int32,
false,
);
test_i32_result(
last_value,
Expand All @@ -330,6 +378,7 @@ mod tests {
Arc::new(Column::new("arr", 0)),
DataType::Int32,
1,
false,
)?;
test_i32_result(nth_value, Int32Array::from(vec![1; 8]))?;
Ok(())
Expand All @@ -342,6 +391,7 @@ mod tests {
Arc::new(Column::new("arr", 0)),
DataType::Int32,
2,
false,
)?;
test_i32_result(
nth_value,
Expand Down
12 changes: 8 additions & 4 deletions datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1179,15 +1179,19 @@ mod tests {
.map(|e| Arc::new(e) as Arc<dyn ExecutionPlan>)?;
let col_a = col("a", &schema)?;
let nth_value_func1 =
NthValue::nth("nth_value(-1)", col_a.clone(), DataType::Int32, 1)?
NthValue::nth("nth_value(-1)", col_a.clone(), DataType::Int32, 1, false)?
.reverse_expr()
.unwrap();
let nth_value_func2 =
NthValue::nth("nth_value(-2)", col_a.clone(), DataType::Int32, 2)?
NthValue::nth("nth_value(-2)", col_a.clone(), DataType::Int32, 2, false)?
.reverse_expr()
.unwrap();
let last_value_func =
Arc::new(NthValue::last("last", col_a.clone(), DataType::Int32)) as _;
let last_value_func = Arc::new(NthValue::last(
"last",
col_a.clone(),
DataType::Int32,
false,
)) as _;
let window_exprs = vec![
// LAST_VALUE(a)
Arc::new(BuiltInWindowExpr::new(
Expand Down
12 changes: 9 additions & 3 deletions datafusion/physical-plan/src/windows/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,15 +250,21 @@ fn create_built_in_window_expr(
.try_into()
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
let n: u32 = n as u32;
Arc::new(NthValue::nth(name, arg, data_type.clone(), n)?)
Arc::new(NthValue::nth(
name,
arg,
data_type.clone(),
n,
ignore_nulls,
)?)
}
BuiltInWindowFunction::FirstValue => {
let arg = args[0].clone();
Arc::new(NthValue::first(name, arg, data_type.clone()))
Arc::new(NthValue::first(name, arg, data_type.clone(), ignore_nulls))
}
BuiltInWindowFunction::LastValue => {
let arg = args[0].clone();
Arc::new(NthValue::last(name, arg, data_type.clone()))
Arc::new(NthValue::last(name, arg, data_type.clone(), ignore_nulls))
}
})
}
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/tests/cases/roundtrip_physical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ fn roundtrip_window() -> Result<()> {
"FIRST_VALUE(a) PARTITION BY [b] ORDER BY [a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW",
col("a", &schema)?,
DataType::Int64,
false,
)),
&[col("b", &schema)?],
&[PhysicalSortExpr {
Expand Down
Loading

0 comments on commit 92d046d

Please sign in to comment.