Skip to content

Commit

Permalink
Remove type coercions from ScalarValue and aggregation function code (#…
Browse files Browse the repository at this point in the history
…3705)

* Sanitize ScalarValue and aggregation code from type coercions

* Remove forced type cast from sum_row! macro used in SumRowAccumulator
  • Loading branch information
ozankabak authored Oct 6, 2022
1 parent 88eadc4 commit 8dcef91
Show file tree
Hide file tree
Showing 11 changed files with 334 additions and 795 deletions.
470 changes: 192 additions & 278 deletions datafusion/common/src/scalar.rs

Large diffs are not rendered by default.

66 changes: 11 additions & 55 deletions datafusion/physical-expr/src/aggregate/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,6 @@ impl RowAccumulator for AvgRowAccumulator {

// sum
sum::add_to_row(
&self.sum_datatype,
self.state_index() + 1,
accessor,
&sum::sum_batch(values, &self.sum_datatype)?,
Expand All @@ -249,12 +248,8 @@ impl RowAccumulator for AvgRowAccumulator {
accessor.add_u64(self.state_index(), delta);

// sum
sum::add_to_row(
&self.sum_datatype,
self.state_index() + 1,
accessor,
&sum::sum_batch(&states[1], &self.sum_datatype)?,
)?;
let difference = sum::sum_batch(&states[1], &self.sum_datatype)?;
sum::add_to_row(self.state_index() + 1, accessor, &difference)?;
Ok(())
}

Expand Down Expand Up @@ -301,8 +296,7 @@ mod tests {
array,
DataType::Decimal128(10, 0),
Avg,
ScalarValue::Decimal128(Some(35000), 14, 4),
DataType::Decimal128(14, 4)
ScalarValue::Decimal128(Some(35000), 14, 4)
)
}

Expand All @@ -318,8 +312,7 @@ mod tests {
array,
DataType::Decimal128(10, 0),
Avg,
ScalarValue::Decimal128(Some(32500), 14, 4),
DataType::Decimal128(14, 4)
ScalarValue::Decimal128(Some(32500), 14, 4)
)
}

Expand All @@ -337,21 +330,14 @@ mod tests {
array,
DataType::Decimal128(10, 0),
Avg,
ScalarValue::Decimal128(None, 14, 4),
DataType::Decimal128(14, 4)
ScalarValue::Decimal128(None, 14, 4)
)
}

#[test]
fn avg_i32() -> Result<()> {
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
generic_test_op!(
a,
DataType::Int32,
Avg,
ScalarValue::from(3_f64),
DataType::Float64
)
generic_test_op!(a, DataType::Int32, Avg, ScalarValue::from(3_f64))
}

#[test]
Expand All @@ -363,63 +349,33 @@ mod tests {
Some(4),
Some(5),
]));
generic_test_op!(
a,
DataType::Int32,
Avg,
ScalarValue::from(3.25f64),
DataType::Float64
)
generic_test_op!(a, DataType::Int32, Avg, ScalarValue::from(3.25f64))
}

#[test]
fn avg_i32_all_nulls() -> Result<()> {
let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None]));
generic_test_op!(
a,
DataType::Int32,
Avg,
ScalarValue::Float64(None),
DataType::Float64
)
generic_test_op!(a, DataType::Int32, Avg, ScalarValue::Float64(None))
}

#[test]
fn avg_u32() -> Result<()> {
let a: ArrayRef =
Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32]));
generic_test_op!(
a,
DataType::UInt32,
Avg,
ScalarValue::from(3.0f64),
DataType::Float64
)
generic_test_op!(a, DataType::UInt32, Avg, ScalarValue::from(3.0f64))
}

#[test]
fn avg_f32() -> Result<()> {
let a: ArrayRef =
Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32]));
generic_test_op!(
a,
DataType::Float32,
Avg,
ScalarValue::from(3_f64),
DataType::Float64
)
generic_test_op!(a, DataType::Float32, Avg, ScalarValue::from(3_f64))
}

#[test]
fn avg_f64() -> Result<()> {
let a: ArrayRef =
Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]));
generic_test_op!(
a,
DataType::Float64,
Avg,
ScalarValue::from(3_f64),
DataType::Float64
)
generic_test_op!(a, DataType::Float64, Avg, ScalarValue::from(3_f64))
}
}
24 changes: 8 additions & 16 deletions datafusion/physical-expr/src/aggregate/correlation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,7 @@ mod tests {
DataType::Float64,
DataType::Float64,
Correlation,
ScalarValue::from(0.9819805060619659),
DataType::Float64
ScalarValue::from(0.9819805060619659_f64)
)
}

Expand All @@ -233,8 +232,7 @@ mod tests {
DataType::Float64,
DataType::Float64,
Correlation,
ScalarValue::from(0.17066403719657236),
DataType::Float64
ScalarValue::from(0.17066403719657236_f64)
)
}

Expand All @@ -249,8 +247,7 @@ mod tests {
DataType::Float64,
DataType::Float64,
Correlation,
ScalarValue::from(1_f64),
DataType::Float64
ScalarValue::from(1_f64)
)
}

Expand All @@ -269,8 +266,7 @@ mod tests {
DataType::Float64,
DataType::Float64,
Correlation,
ScalarValue::from(0.9860135594710389),
DataType::Float64
ScalarValue::from(0.9860135594710389_f64)
)
}

Expand All @@ -285,8 +281,7 @@ mod tests {
DataType::Int32,
DataType::Int32,
Correlation,
ScalarValue::from(1_f64),
DataType::Float64
ScalarValue::from(1_f64)
)
}

Expand All @@ -300,8 +295,7 @@ mod tests {
DataType::UInt32,
DataType::UInt32,
Correlation,
ScalarValue::from(1_f64),
DataType::Float64
ScalarValue::from(1_f64)
)
}

Expand All @@ -315,8 +309,7 @@ mod tests {
DataType::Float32,
DataType::Float32,
Correlation,
ScalarValue::from(1_f64),
DataType::Float64
ScalarValue::from(1_f64)
)
}

Expand All @@ -333,8 +326,7 @@ mod tests {
DataType::Int32,
DataType::Int32,
Correlation,
ScalarValue::from(0.1889822365046137),
DataType::Float64
ScalarValue::from(0.1889822365046137_f64)
)
}

Expand Down
48 changes: 6 additions & 42 deletions datafusion/physical-expr/src/aggregate/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,13 +210,7 @@ mod tests {
#[test]
fn count_elements() -> Result<()> {
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
generic_test_op!(
a,
DataType::Int32,
Count,
ScalarValue::from(5i64),
DataType::Int64
)
generic_test_op!(a, DataType::Int32, Count, ScalarValue::from(5i64))
}

#[test]
Expand All @@ -229,65 +223,35 @@ mod tests {
Some(3),
None,
]));
generic_test_op!(
a,
DataType::Int32,
Count,
ScalarValue::from(3i64),
DataType::Int64
)
generic_test_op!(a, DataType::Int32, Count, ScalarValue::from(3i64))
}

#[test]
fn count_all_nulls() -> Result<()> {
let a: ArrayRef = Arc::new(BooleanArray::from(vec![
None, None, None, None, None, None, None, None,
]));
generic_test_op!(
a,
DataType::Boolean,
Count,
ScalarValue::from(0i64),
DataType::Int64
)
generic_test_op!(a, DataType::Boolean, Count, ScalarValue::from(0i64))
}

#[test]
fn count_empty() -> Result<()> {
let a: Vec<bool> = vec![];
let a: ArrayRef = Arc::new(BooleanArray::from(a));
generic_test_op!(
a,
DataType::Boolean,
Count,
ScalarValue::from(0i64),
DataType::Int64
)
generic_test_op!(a, DataType::Boolean, Count, ScalarValue::from(0i64))
}

#[test]
fn count_utf8() -> Result<()> {
let a: ArrayRef =
Arc::new(StringArray::from(vec!["a", "bb", "ccc", "dddd", "ad"]));
generic_test_op!(
a,
DataType::Utf8,
Count,
ScalarValue::from(5i64),
DataType::Int64
)
generic_test_op!(a, DataType::Utf8, Count, ScalarValue::from(5i64))
}

#[test]
fn count_large_utf8() -> Result<()> {
let a: ArrayRef =
Arc::new(LargeStringArray::from(vec!["a", "bb", "ccc", "dddd", "ad"]));
generic_test_op!(
a,
DataType::LargeUtf8,
Count,
ScalarValue::from(5i64),
DataType::Int64
)
generic_test_op!(a, DataType::LargeUtf8, Count, ScalarValue::from(5i64))
}
}
27 changes: 9 additions & 18 deletions datafusion/physical-expr/src/aggregate/covariance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -397,8 +397,7 @@ mod tests {
DataType::Float64,
DataType::Float64,
CovariancePop,
ScalarValue::from(0.6666666666666666),
DataType::Float64
ScalarValue::from(0.6666666666666666_f64)
)
}

Expand All @@ -413,8 +412,7 @@ mod tests {
DataType::Float64,
DataType::Float64,
Covariance,
ScalarValue::from(1_f64),
DataType::Float64
ScalarValue::from(1_f64)
)
}

Expand All @@ -429,8 +427,7 @@ mod tests {
DataType::Float64,
DataType::Float64,
Covariance,
ScalarValue::from(0.9033333333333335_f64),
DataType::Float64
ScalarValue::from(0.9033333333333335_f64)
)
}

Expand All @@ -445,8 +442,7 @@ mod tests {
DataType::Float64,
DataType::Float64,
CovariancePop,
ScalarValue::from(0.6022222222222223_f64),
DataType::Float64
ScalarValue::from(0.6022222222222223_f64)
)
}

Expand All @@ -465,8 +461,7 @@ mod tests {
DataType::Float64,
DataType::Float64,
CovariancePop,
ScalarValue::from(0.7616666666666666),
DataType::Float64
ScalarValue::from(0.7616666666666666_f64)
)
}

Expand All @@ -481,8 +476,7 @@ mod tests {
DataType::Int32,
DataType::Int32,
CovariancePop,
ScalarValue::from(0.6666666666666666_f64),
DataType::Float64
ScalarValue::from(0.6666666666666666_f64)
)
}

Expand All @@ -496,8 +490,7 @@ mod tests {
DataType::UInt32,
DataType::UInt32,
CovariancePop,
ScalarValue::from(0.6666666666666666_f64),
DataType::Float64
ScalarValue::from(0.6666666666666666_f64)
)
}

Expand All @@ -511,8 +504,7 @@ mod tests {
DataType::Float32,
DataType::Float32,
CovariancePop,
ScalarValue::from(0.6666666666666666_f64),
DataType::Float64
ScalarValue::from(0.6666666666666666_f64)
)
}

Expand All @@ -527,8 +519,7 @@ mod tests {
DataType::Int32,
DataType::Int32,
CovariancePop,
ScalarValue::from(1_f64),
DataType::Float64
ScalarValue::from(1_f64)
)
}

Expand Down
Loading

0 comments on commit 8dcef91

Please sign in to comment.