diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index c519d41aad01..fa9523a76380 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -19,6 +19,8 @@ use arrow::{array::ArrayRef, datatypes::Schema}; use arrow_array::BooleanArray; use arrow_schema::FieldRef; use datafusion_common::{Column, ScalarValue}; +use parquet::basic::Type; +use parquet::data_type::Decimal; use parquet::file::metadata::ColumnChunkMetaData; use parquet::schema::types::SchemaDescriptor; use parquet::{ @@ -143,7 +145,10 @@ pub(crate) async fn prune_row_groups_by_bloom_filters< continue; } }; - column_sbbf.insert(column_name.to_string(), bf); + let physical_type = + builder.parquet_schema().column(column_idx).physical_type(); + + column_sbbf.insert(column_name.to_string(), (bf, physical_type)); } let stats = BloomFilterStatistics { column_sbbf }; @@ -169,8 +174,8 @@ pub(crate) async fn prune_row_groups_by_bloom_filters< /// Implements `PruningStatistics` for Parquet Split Block Bloom Filters (SBBF) struct BloomFilterStatistics { - /// Maps column name to the parquet bloom filter - column_sbbf: HashMap, + /// Maps column name to the parquet bloom filter and parquet physical type + column_sbbf: HashMap, } impl PruningStatistics for BloomFilterStatistics { @@ -200,7 +205,7 @@ impl PruningStatistics for BloomFilterStatistics { column: &Column, values: &HashSet, ) -> Option { - let sbbf = self.column_sbbf.get(column.name.as_str())?; + let (sbbf, parquet_type) = self.column_sbbf.get(column.name.as_str())?; // Bloom filters are probabilistic data structures that can return false // positives (i.e. 
it might return true even if the value is not @@ -209,16 +214,63 @@ impl PruningStatistics for BloomFilterStatistics { let known_not_present = values .iter() - .map(|value| match value { - ScalarValue::Utf8(Some(v)) => sbbf.check(&v.as_str()), - ScalarValue::Boolean(Some(v)) => sbbf.check(v), - ScalarValue::Float64(Some(v)) => sbbf.check(v), - ScalarValue::Float32(Some(v)) => sbbf.check(v), - ScalarValue::Int64(Some(v)) => sbbf.check(v), - ScalarValue::Int32(Some(v)) => sbbf.check(v), - ScalarValue::Int16(Some(v)) => sbbf.check(v), - ScalarValue::Int8(Some(v)) => sbbf.check(v), - _ => true, + .map(|value| { + match value { + ScalarValue::Utf8(Some(v)) => sbbf.check(&v.as_str()), + ScalarValue::Boolean(Some(v)) => sbbf.check(v), + ScalarValue::Float64(Some(v)) => sbbf.check(v), + ScalarValue::Float32(Some(v)) => sbbf.check(v), + ScalarValue::Int64(Some(v)) => sbbf.check(v), + ScalarValue::Int32(Some(v)) => sbbf.check(v), + ScalarValue::Int16(Some(v)) => sbbf.check(v), + ScalarValue::Int8(Some(v)) => sbbf.check(v), + ScalarValue::Decimal128(Some(v), p, s) => match parquet_type { + Type::INT32 => { + //https://github.com/apache/parquet-format/blob/eb4b31c1d64a01088d02a2f9aefc6c17c54cc6fc/Encodings.md?plain=1#L35-L42 + // All physical types are little-endian + if *p > 9 { + //DECIMAL can be used to annotate the following types: + // + // int32: for 1 <= precision <= 9 + // int64: for 1 <= precision <= 18 + return true; + } + let b = (*v as i32).to_le_bytes(); + // Use Decimal constructor after https://github.com/apache/arrow-rs/issues/5325 + let decimal = Decimal::Int32 { + value: b, + precision: *p as i32, + scale: *s as i32, + }; + sbbf.check(&decimal) + } + Type::INT64 => { + if *p > 18 { + return true; + } + let b = (*v as i64).to_le_bytes(); + let decimal = Decimal::Int64 { + value: b, + precision: *p as i32, + scale: *s as i32, + }; + sbbf.check(&decimal) + } + Type::FIXED_LEN_BYTE_ARRAY => { + // keep consistent with from_bytes_to_i128 + let b = v.to_be_bytes().to_vec(); + 
// Use Decimal constructor after https://github.com/apache/arrow-rs/issues/5325 + let decimal = Decimal::Bytes { + value: b.into(), + precision: *p as i32, + scale: *s as i32, + }; + sbbf.check(&decimal) + } + _ => true, + }, + _ => true, + } }) // The row group doesn't contain any of the values if // all the checks are false diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 0602b4d4c525..b056db6a0bd3 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -66,7 +66,10 @@ enum Scenario { Int32Range, Float64, Decimal, + DecimalBloomFilterInt32, + DecimalBloomFilterInt64, DecimalLargePrecision, + DecimalLargePrecisionBloomFilter, PeriodsInColumnNames, } @@ -549,6 +552,22 @@ fn create_data_batch(scenario: Scenario) -> Vec { make_decimal_batch(vec![2000, 3000, 3000, 4000, 6000], 9, 2), ] } + Scenario::DecimalBloomFilterInt32 => { + // decimal record batch + vec![ + make_decimal_batch(vec![100, 200, 300, 400, 500], 6, 2), + make_decimal_batch(vec![100, 200, 300, 400, 600], 6, 2), + make_decimal_batch(vec![100, 200, 300, 400, 600], 6, 2), + ] + } + Scenario::DecimalBloomFilterInt64 => { + // decimal record batch + vec![ + make_decimal_batch(vec![100, 200, 300, 400, 500], 9, 2), + make_decimal_batch(vec![100, 200, 300, 400, 600], 9, 2), + make_decimal_batch(vec![100, 200, 300, 400, 600], 9, 2), + ] + } Scenario::DecimalLargePrecision => { // decimal record batch with large precision, // and the data will stored as FIXED_LENGTH_BYTE_ARRAY @@ -558,6 +577,15 @@ fn create_data_batch(scenario: Scenario) -> Vec { make_decimal_batch(vec![2000, 3000, 3000, 4000, 6000], 38, 2), ] } + Scenario::DecimalLargePrecisionBloomFilter => { + // decimal record batch with large precision, + // and the data will be stored as FIXED_LEN_BYTE_ARRAY + vec![ + make_decimal_batch(vec![100000, 200000, 300000, 400000, 500000], 38, 5), + make_decimal_batch(vec![-100000, 200000, 300000, 400000, 600000], 38, 5), + 
make_decimal_batch(vec![100000, 200000, 300000, 400000, 600000], 38, 5), + ] + } Scenario::PeriodsInColumnNames => { vec![ // all frontend diff --git a/datafusion/core/tests/parquet/row_group_pruning.rs b/datafusion/core/tests/parquet/row_group_pruning.rs index c8cac5dd9b7a..449a311777dc 100644 --- a/datafusion/core/tests/parquet/row_group_pruning.rs +++ b/datafusion/core/tests/parquet/row_group_pruning.rs @@ -599,6 +599,39 @@ async fn prune_decimal_in_list() { .with_expected_rows(6) .test_row_group_prune() .await; + + // test data -> r1: {1,2,3,4,5}, r2: {1,2,3,4,6}, r3: {1,2,3,4,6} + RowGroupPruningTest::new() + .with_scenario(Scenario::DecimalBloomFilterInt32) + .with_query("SELECT * FROM t where decimal_col in (5)") + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(0)) + .with_pruned_by_bloom_filter(Some(2)) + .with_expected_rows(1) + .test_row_group_prune() + .await; + + // test data -> r1: {1,2,3,4,5}, r2: {1,2,3,4,6}, r3: {1,2,3,4,6} + RowGroupPruningTest::new() + .with_scenario(Scenario::DecimalBloomFilterInt64) + .with_query("SELECT * FROM t where decimal_col in (5)") + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(0)) + .with_pruned_by_bloom_filter(Some(2)) + .with_expected_rows(1) + .test_row_group_prune() + .await; + + // test data -> r1: {1,2,3,4,5}, r2: {-1,2,3,4,6}, r3: {1,2,3,4,6} + RowGroupPruningTest::new() + .with_scenario(Scenario::DecimalLargePrecisionBloomFilter) + .with_query("SELECT * FROM t where decimal_col in (5)") + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(0)) + .with_pruned_by_bloom_filter(Some(2)) + .with_expected_rows(1) + .test_row_group_prune() + .await; } #[tokio::test]