diff --git a/src/common/recordbatch/src/lib.rs b/src/common/recordbatch/src/lib.rs index 23aa04a9bf61..77987eac254c 100644 --- a/src/common/recordbatch/src/lib.rs +++ b/src/common/recordbatch/src/lib.rs @@ -66,7 +66,7 @@ impl Stream for EmptyRecordBatchStream { } } -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub struct RecordBatches { schema: SchemaRef, batches: Vec, diff --git a/src/common/recordbatch/src/recordbatch.rs b/src/common/recordbatch/src/recordbatch.rs index 47c86831dc0e..c8bbc94c8db2 100644 --- a/src/common/recordbatch/src/recordbatch.rs +++ b/src/common/recordbatch/src/recordbatch.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use datatypes::arrow::util::pretty; use datatypes::schema::SchemaRef; use datatypes::value::Value; use datatypes::vectors::{Helper, VectorRef}; @@ -98,6 +99,13 @@ impl RecordBatch { pub fn rows(&self) -> RecordBatchRowIterator<'_> { RecordBatchRowIterator::new(self) } + + pub fn pretty_print(&self) -> Result { + let df_batch = self.df_record_batch.clone(); + let result = pretty::pretty_format_batches(&[df_batch]).context(error::FormatSnafu)?; + + Ok(result.to_string()) + } } impl Serialize for RecordBatch { diff --git a/src/common/recordbatch/src/util.rs b/src/common/recordbatch/src/util.rs index 1cca3ee9889b..4b2f1a67c84d 100644 --- a/src/common/recordbatch/src/util.rs +++ b/src/common/recordbatch/src/util.rs @@ -15,13 +15,20 @@ use futures::TryStreamExt; use crate::error::Result; -use crate::{RecordBatch, SendableRecordBatchStream}; +use crate::{RecordBatch, RecordBatches, SendableRecordBatchStream}; /// Collect all the items from the stream into a vector of [`RecordBatch`]. pub async fn collect(stream: SendableRecordBatchStream) -> Result> { stream.try_collect::>().await } +/// Collect all the items from the stream into [RecordBatches]. +pub async fn collect_batches(stream: SendableRecordBatchStream) -> Result { + let schema = stream.schema(); + let batches = stream.try_collect::>().await?; + RecordBatches::try_new(schema, batches) +} + #[cfg(test)] mod tests { use std::mem; @@ -90,7 +97,14 @@ mod tests { }; let batches = collect(Box::pin(stream)).await.unwrap(); assert_eq!(1, batches.len()); - assert_eq!(batch, batches[0]); + + let stream = MockRecordBatchStream { + schema: schema.clone(), + batch: Some(batch.clone()), + }; + let batches = collect_batches(Box::pin(stream)).await.unwrap(); + let expect_batches = RecordBatches::try_new(schema.clone(), vec![batch]).unwrap(); + assert_eq!(expect_batches, batches); } } diff --git a/src/query/tests/argmax_test.rs b/src/query/tests/argmax_test.rs index 11f0167a096c..cbf1ae931dc9 100644 --- a/src/query/tests/argmax_test.rs +++ b/src/query/tests/argmax_test.rs @@ -12,16 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::Arc; mod function; + +use std::sync::Arc; + use common_query::Output; use common_recordbatch::error::Result as RecordResult; use common_recordbatch::{util, RecordBatch}; -use datafusion::field_util::{FieldExt, SchemaExt}; use datatypes::for_all_primitive_types; use datatypes::prelude::*; -use datatypes::types::PrimitiveElement; -use function::{create_query_engine, get_numbers_from_table}; +use datatypes::types::WrapperType; use query::error::Result; use query::QueryEngine; use session::context::QueryContext; @@ -29,7 +29,7 @@ use session::context::QueryContext; #[tokio::test] async fn test_argmax_aggregator() -> Result<()> { common_telemetry::init_default_ut_logging(); - let engine = create_query_engine(); + let engine = function::create_query_engine(); macro_rules! test_argmax { ([], $( { $T:ty } ),*) => { @@ -49,33 +49,23 @@ async fn test_argmax_success( engine: Arc, ) -> Result<()> where - T: PrimitiveElement + PartialOrd, - for<'a> T: Scalar = T>, + T: WrapperType + PartialOrd, { let result = execute_argmax(column_name, table_name, engine.clone()) .await .unwrap(); - assert_eq!(1, result.len()); - assert_eq!(result[0].df_recordbatch.num_columns(), 1); - assert_eq!(1, result[0].schema.arrow_schema().fields().len()); - assert_eq!("argmax", result[0].schema.arrow_schema().field(0).name()); - - let columns = result[0].df_recordbatch.columns(); - assert_eq!(1, columns.len()); - assert_eq!(columns[0].len(), 1); - let v = VectorHelper::try_into_vector(&columns[0]).unwrap(); - assert_eq!(1, v.len()); - let value = v.get(0); + let value = function::get_value_from_batches("argmax", result); - let numbers = get_numbers_from_table::(column_name, table_name, engine.clone()).await; + let numbers = + function::get_numbers_from_table::(column_name, table_name, engine.clone()).await; let expected_value = match numbers.len() { 0 => 0_u64, _ => { let mut index = 0; - let mut max = numbers[0].into(); + let mut max = numbers[0]; for (i, &number) in numbers.iter().enumerate() { - if max < number.into() { - max = number.into(); + if max < number { + max = number; index = i; } } diff --git a/src/query/tests/argmin_test.rs b/src/query/tests/argmin_test.rs index 2a509f05fdc1..546fa9ae23f3 100644 --- a/src/query/tests/argmin_test.rs +++ b/src/query/tests/argmin_test.rs @@ -12,17 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::Arc; mod function; +use std::sync::Arc; + use common_query::Output; use common_recordbatch::error::Result as RecordResult; use common_recordbatch::{util, RecordBatch}; -use datafusion::field_util::{FieldExt, SchemaExt}; use datatypes::for_all_primitive_types; use datatypes::prelude::*; -use datatypes::types::PrimitiveElement; -use function::{create_query_engine, get_numbers_from_table}; +use datatypes::types::WrapperType; use query::error::Result; use query::QueryEngine; use session::context::QueryContext; @@ -30,7 +29,7 @@ use session::context::QueryContext; #[tokio::test] async fn test_argmin_aggregator() -> Result<()> { common_telemetry::init_default_ut_logging(); - let engine = create_query_engine(); + let engine = function::create_query_engine(); macro_rules! test_argmin { ([], $( { $T:ty } ),*) => { @@ -50,33 +49,23 @@ async fn test_argmin_success( engine: Arc, ) -> Result<()> where - T: PrimitiveElement + PartialOrd, - for<'a> T: Scalar = T>, + T: WrapperType + PartialOrd, { let result = execute_argmin(column_name, table_name, engine.clone()) .await .unwrap(); - assert_eq!(1, result.len()); - assert_eq!(result[0].df_recordbatch.num_columns(), 1); - assert_eq!(1, result[0].schema.arrow_schema().fields().len()); - assert_eq!("argmin", result[0].schema.arrow_schema().field(0).name()); - - let columns = result[0].df_recordbatch.columns(); - assert_eq!(1, columns.len()); - assert_eq!(columns[0].len(), 1); - let v = VectorHelper::try_into_vector(&columns[0]).unwrap(); - assert_eq!(1, v.len()); - let value = v.get(0); + let value = function::get_value_from_batches("argmin", result); - let numbers = get_numbers_from_table::(column_name, table_name, engine.clone()).await; + let numbers = + function::get_numbers_from_table::(column_name, table_name, engine.clone()).await; let expected_value = match numbers.len() { 0 => 0_u32, _ => { let mut index = 0; - let mut min = numbers[0].into(); + let mut min = numbers[0]; for (i, &number) in numbers.iter().enumerate() { - if min > number.into() { - min = number.into(); + if min > number { + min = number; index = i; } } diff --git a/src/query/tests/function.rs b/src/query/tests/function.rs index bcc26a8dec82..7de93a6265ec 100644 --- a/src/query/tests/function.rs +++ b/src/query/tests/function.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +// FIXME(yingwen): Consider move all tests under query/tests to query/src so we could reuse +// more codes. use std::sync::Arc; use catalog::local::{MemoryCatalogManager, MemoryCatalogProvider, MemorySchemaProvider}; @@ -23,7 +25,7 @@ use datatypes::for_all_primitive_types; use datatypes::prelude::*; use datatypes::schema::{ColumnSchema, Schema}; use datatypes::types::WrapperType; -use datatypes::vectors::PrimitiveVector; +use datatypes::vectors::Helper; use query::query_engine::QueryEngineFactory; use query::QueryEngine; use rand::Rng; @@ -47,7 +49,7 @@ pub fn create_query_engine() -> Arc { column_schemas.push(column_schema); let numbers = (1..=10).map(|_| rng.gen::<$T>()).collect::>(); - let column: VectorRef = Arc::new(PrimitiveVector::<$T>::from_vec(numbers.to_vec())); + let column: VectorRef = Arc::new(<$T as Scalar>::VectorType::from_vec(numbers.to_vec())); columns.push(column); )* } @@ -92,6 +94,20 @@ where let numbers = util::collect(recordbatch_stream).await.unwrap(); let column = numbers[0].column(0); - let column: &::VectorType = unsafe { VectorHelper::static_cast(column) }; + let column: &::VectorType = unsafe { Helper::static_cast(column) }; column.iter_data().flatten().collect::>() } + +pub fn get_value_from_batches(column_name: &str, batches: Vec) -> Value { + assert_eq!(1, batches.len()); + assert_eq!(batches[0].num_columns(), 1); + assert_eq!(1, batches[0].schema.num_columns()); + assert_eq!(column_name, batches[0].schema.column_schemas()[0].name); + + let batch = &batches[0]; + assert_eq!(1, batch.num_columns()); + assert_eq!(batch.column(0).len(), 1); + let v = batch.column(0); + assert_eq!(1, v.len()); + v.get(0) +} diff --git a/src/query/tests/mean_test.rs b/src/query/tests/mean_test.rs index 705dea797db1..000323fb2192 100644 --- a/src/query/tests/mean_test.rs +++ b/src/query/tests/mean_test.rs @@ -12,19 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::Arc; mod function; +use std::sync::Arc; + use common_query::Output; use common_recordbatch::error::Result as RecordResult; use common_recordbatch::{util, RecordBatch}; -use datafusion::field_util::{FieldExt, SchemaExt}; use datatypes::for_all_primitive_types; use datatypes::prelude::*; -use datatypes::types::PrimitiveElement; +use datatypes::types::WrapperType; use datatypes::value::OrderedFloat; use format_num::NumberFormat; -use function::{create_query_engine, get_numbers_from_table}; use num_traits::AsPrimitive; use query::error::Result; use query::QueryEngine; @@ -33,7 +32,7 @@ use session::context::QueryContext; #[tokio::test] async fn test_mean_aggregator() -> Result<()> { common_telemetry::init_default_ut_logging(); - let engine = create_query_engine(); + let engine = function::create_query_engine(); macro_rules! test_mean { ([], $( { $T:ty } ),*) => { @@ -53,25 +52,15 @@ async fn test_mean_success( engine: Arc, ) -> Result<()> where - T: PrimitiveElement + AsPrimitive, - for<'a> T: Scalar = T>, + T: WrapperType + AsPrimitive, { let result = execute_mean(column_name, table_name, engine.clone()) .await .unwrap(); - assert_eq!(1, result.len()); - assert_eq!(result[0].df_recordbatch.num_columns(), 1); - assert_eq!(1, result[0].schema.arrow_schema().fields().len()); - assert_eq!("mean", result[0].schema.arrow_schema().field(0).name()); - - let columns = result[0].df_recordbatch.columns(); - assert_eq!(1, columns.len()); - assert_eq!(columns[0].len(), 1); - let v = VectorHelper::try_into_vector(&columns[0]).unwrap(); - assert_eq!(1, v.len()); - let value = v.get(0); + let value = function::get_value_from_batches("mean", result); - let numbers = get_numbers_from_table::(column_name, table_name, engine.clone()).await; + let numbers = + function::get_numbers_from_table::(column_name, table_name, engine.clone()).await; let expected_value = numbers.iter().map(|&n| n.as_()).collect::>(); let expected_value = inc_stats::mean(expected_value.iter().cloned()).unwrap(); diff --git a/src/query/tests/my_sum_udaf_example.rs b/src/query/tests/my_sum_udaf_example.rs index 4e05183861ee..6f132a95b2b0 100644 --- a/src/query/tests/my_sum_udaf_example.rs +++ b/src/query/tests/my_sum_udaf_example.rs @@ -26,12 +26,10 @@ use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; use common_query::prelude::*; use common_query::Output; use common_recordbatch::{util, RecordBatch}; -use datafusion::arrow_print; -use datafusion_common::record_batch::RecordBatch as DfRecordBatch; use datatypes::prelude::*; use datatypes::schema::{ColumnSchema, Schema}; -use datatypes::types::{PrimitiveElement, PrimitiveType}; -use datatypes::vectors::PrimitiveVector; +use datatypes::types::{LogicalPrimitiveType, WrapperType}; +use datatypes::vectors::Helper; use datatypes::with_match_primitive_type_id; use num_traits::AsPrimitive; use query::error::Result; @@ -40,28 +38,30 @@ use session::context::QueryContext; use table::test_util::MemTable; #[derive(Debug, Default)] -struct MySumAccumulator -where - T: Primitive + AsPrimitive, - SumT: Primitive + std::ops::AddAssign, -{ +struct MySumAccumulator { sum: SumT, _phantom: PhantomData, } impl MySumAccumulator where - T: Primitive + AsPrimitive, - SumT: Primitive + std::ops::AddAssign, + T: WrapperType, + SumT: WrapperType, + T::Native: AsPrimitive, + SumT::Native: std::ops::AddAssign, { #[inline(always)] fn add(&mut self, v: T) { - self.sum += v.as_(); + let mut sum_native = self.sum.into_native(); + sum_native += v.into_native().as_(); + self.sum = SumT::from_native(sum_native); } #[inline(always)] fn merge(&mut self, s: SumT) { - self.sum += s; + let mut sum_native = self.sum.into_native(); + sum_native += s.into_native(); + self.sum = SumT::from_native(sum_native); } } @@ -76,7 +76,7 @@ impl AggregateFunctionCreator for MySumAccumulatorCreator { with_match_primitive_type_id!( input_type.logical_type_id(), |$S| { - Ok(Box::new(MySumAccumulator::<$S, <$S as Primitive>::LargestType>::default())) + Ok(Box::new(MySumAccumulator::<<$S as LogicalPrimitiveType>::Wrapper, <<$S as LogicalPrimitiveType>::LargestType as LogicalPrimitiveType>::Wrapper>::default())) }, { let err_msg = format!( @@ -95,7 +95,7 @@ impl AggregateFunctionCreator for MySumAccumulatorCreator { with_match_primitive_type_id!( input_type.logical_type_id(), |$S| { - Ok(PrimitiveType::<<$S as Primitive>::LargestType>::default().logical_type_id().data_type()) + Ok(<<$S as LogicalPrimitiveType>::LargestType>::build_data_type()) }, { unreachable!() @@ -110,10 +110,10 @@ impl AggregateFunctionCreator for MySumAccumulatorCreator { impl Accumulator for MySumAccumulator where - T: Primitive + AsPrimitive, - for<'a> T: Scalar = T>, - SumT: Primitive + std::ops::AddAssign, - for<'a> SumT: Scalar = SumT>, + T: WrapperType, + SumT: WrapperType, + T::Native: AsPrimitive, + SumT::Native: std::ops::AddAssign, { fn state(&self) -> QueryResult> { Ok(vec![self.sum.into()]) @@ -124,7 +124,7 @@ where return Ok(()); }; let column = &values[0]; - let column: &::VectorType = unsafe { VectorHelper::static_cast(column) }; + let column: &::VectorType = unsafe { Helper::static_cast(column) }; for v in column.iter_data().flatten() { self.add(v) } @@ -136,7 +136,7 @@ where return Ok(()); }; let states = &states[0]; - let states: &::VectorType = unsafe { VectorHelper::static_cast(states) }; + let states: &::VectorType = unsafe { Helper::static_cast(states) }; for s in states.iter_data().flatten() { self.merge(s) } @@ -201,18 +201,18 @@ async fn test_my_sum() -> Result<()> { async fn test_my_sum_with(numbers: Vec, expected: Vec<&str>) -> Result<()> where - T: PrimitiveElement, + T: WrapperType, { let table_name = format!("{}_numbers", std::any::type_name::()); let column_name = format!("{}_number", std::any::type_name::()); let column_schemas = vec![ColumnSchema::new( column_name.clone(), - T::build_data_type(), + T::LogicalType::build_data_type(), true, )]; let schema = Arc::new(Schema::new(column_schemas.clone())); - let column: VectorRef = Arc::new(PrimitiveVector::::from_vec(numbers)); + let column: VectorRef = Arc::new(T::VectorType::from_vec(numbers)); let recordbatch = RecordBatch::new(schema, vec![column]).unwrap(); let testing_table = MemTable::new(&table_name, recordbatch); @@ -236,15 +236,11 @@ where Output::Stream(batch) => batch, _ => unreachable!(), }; - let recordbatch = util::collect(recordbatch_stream).await.unwrap(); - let df_recordbatch = recordbatch - .into_iter() - .map(|r| r.df_recordbatch) - .collect::>(); + let batches = util::collect_batches(recordbatch_stream).await.unwrap(); - let pretty_print = arrow_print::write(&df_recordbatch); - let pretty_print = pretty_print.lines().collect::>(); - assert_eq!(expected, pretty_print); + let pretty_print = batches.pretty_print().unwrap(); + // TODO(yingwen): Check pretty print output. + assert_eq!(expected, vec![pretty_print]); Ok(()) } diff --git a/src/query/tests/percentile_test.rs b/src/query/tests/percentile_test.rs index 6e210a0494e0..d41a55d44b82 100644 --- a/src/query/tests/percentile_test.rs +++ b/src/query/tests/percentile_test.rs @@ -20,12 +20,10 @@ use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_query::Output; use common_recordbatch::error::Result as RecordResult; use common_recordbatch::{util, RecordBatch}; -use datafusion::field_util::{FieldExt, SchemaExt}; use datatypes::for_all_primitive_types; use datatypes::prelude::*; use datatypes::schema::{ColumnSchema, Schema}; -use datatypes::types::PrimitiveElement; -use datatypes::vectors::PrimitiveVector; +use datatypes::vectors::Int32Vector; use function::{create_query_engine, get_numbers_from_table}; use num_traits::AsPrimitive; use query::error::Result; @@ -64,9 +62,8 @@ async fn test_percentile_correctness() -> Result<()> { _ => unreachable!(), }; let record_batch = util::collect(recordbatch_stream).await.unwrap(); - let columns = record_batch[0].df_recordbatch.columns(); - let v = VectorHelper::try_into_vector(&columns[0]).unwrap(); - let value = v.get(0); + let column = record_batch[0].column(0); + let value = column.get(0); assert_eq!(value, Value::from(9.280_000_000_000_001_f64)); Ok(()) } @@ -77,26 +74,12 @@ async fn test_percentile_success( engine: Arc, ) -> Result<()> where - T: PrimitiveElement + AsPrimitive, - for<'a> T: Scalar = T>, + T: WrapperType + AsPrimitive, { let result = execute_percentile(column_name, table_name, engine.clone()) .await .unwrap(); - assert_eq!(1, result.len()); - assert_eq!(result[0].df_recordbatch.num_columns(), 1); - assert_eq!(1, result[0].schema.arrow_schema().fields().len()); - assert_eq!( - "percentile", - result[0].schema.arrow_schema().field(0).name() - ); - - let columns = result[0].df_recordbatch.columns(); - assert_eq!(1, columns.len()); - assert_eq!(columns[0].len(), 1); - let v = VectorHelper::try_into_vector(&columns[0]).unwrap(); - assert_eq!(1, v.len()); - let value = v.get(0); + let value = function::get_value_from_batches("percentile", result); let numbers = get_numbers_from_table::(column_name, table_name, engine.clone()).await; let expected_value = numbers.iter().map(|&n| n.as_()).collect::>(); @@ -142,7 +125,7 @@ fn create_correctness_engine() -> Arc { let numbers = vec![3_i32, 6_i32, 8_i32, 10_i32]; - let column: VectorRef = Arc::new(PrimitiveVector::::from_vec(numbers.to_vec())); + let column: VectorRef = Arc::new(Int32Vector::from_vec(numbers.to_vec())); columns.push(column); let schema = Arc::new(Schema::new(column_schemas)); diff --git a/src/query/tests/polyval_test.rs b/src/query/tests/polyval_test.rs index f2e60c0217ca..248c0d42d74e 100644 --- a/src/query/tests/polyval_test.rs +++ b/src/query/tests/polyval_test.rs @@ -18,11 +18,9 @@ mod function; use common_query::Output; use common_recordbatch::error::Result as RecordResult; use common_recordbatch::{util, RecordBatch}; -use datafusion::field_util::{FieldExt, SchemaExt}; use datatypes::for_all_primitive_types; use datatypes::prelude::*; -use datatypes::types::PrimitiveElement; -use function::{create_query_engine, get_numbers_from_table}; +use datatypes::types::WrapperType; use num_traits::AsPrimitive; use query::error::Result; use query::QueryEngine; @@ -31,13 +29,13 @@ use session::context::QueryContext; #[tokio::test] async fn test_polyval_aggregator() -> Result<()> { common_telemetry::init_default_ut_logging(); - let engine = create_query_engine(); + let engine = function::create_query_engine(); macro_rules! test_polyval { ([], $( { $T:ty } ),*) => { $( let column_name = format!("{}_number", std::any::type_name::<$T>()); - test_polyval_success::<$T,<$T as Primitive>::LargestType>(&column_name, "numbers", engine.clone()).await?; + test_polyval_success::<$T, <<<$T as WrapperType>::LogicalType as LogicalPrimitiveType>::LargestType as LogicalPrimitiveType>::Wrapper>(&column_name, "numbers", engine.clone()).await?; )* } } @@ -51,36 +49,27 @@ async fn test_polyval_success( engine: Arc, ) -> Result<()> where - T: Primitive + AsPrimitive + PrimitiveElement, - PolyT: Primitive + std::ops::Mul + std::iter::Sum, - for<'a> T: Scalar = T>, - for<'a> PolyT: Scalar = PolyT>, - i64: AsPrimitive, + T: WrapperType, + PolyT: WrapperType, + T::Native: AsPrimitive, + PolyT::Native: std::ops::Mul + std::iter::Sum, + i64: AsPrimitive, { let result = execute_polyval(column_name, table_name, engine.clone()) .await .unwrap(); - assert_eq!(1, result.len()); - assert_eq!(result[0].df_recordbatch.num_columns(), 1); - assert_eq!(1, result[0].schema.arrow_schema().fields().len()); - assert_eq!("polyval", result[0].schema.arrow_schema().field(0).name()); + let value = function::get_value_from_batches("polyval", result); - let columns = result[0].df_recordbatch.columns(); - assert_eq!(1, columns.len()); - assert_eq!(columns[0].len(), 1); - let v = VectorHelper::try_into_vector(&columns[0]).unwrap(); - assert_eq!(1, v.len()); - let value = v.get(0); - - let numbers = get_numbers_from_table::(column_name, table_name, engine.clone()).await; + let numbers = + function::get_numbers_from_table::(column_name, table_name, engine.clone()).await; let expected_value = numbers.iter().copied(); let x = 0i64; let len = expected_value.len(); - let expected_value: PolyT = expected_value + let expected_native: PolyT::Native = expected_value .enumerate() - .map(|(i, value)| value.as_() * (x.pow((len - 1 - i) as u32)).as_()) + .map(|(i, v)| v.into_native().as_() * (x.pow((len - 1 - i) as u32)).as_()) .sum(); - assert_eq!(value, expected_value.into()); + assert_eq!(value, PolyT::from_native(expected_native).into()); Ok(()) } diff --git a/src/query/tests/query_engine_test.rs b/src/query/tests/query_engine_test.rs index cf640afba48e..c059a4d0d888 100644 --- a/src/query/tests/query_engine_test.rs +++ b/src/query/tests/query_engine_test.rs @@ -13,6 +13,11 @@ // limitations under the License. mod pow; +// This is used to suppress the warning: function `create_query_engine` is never used. +// FIXME(yingwen): We finally need to refactor these tests and move them to `query/src` +// so tests can share codes with other mods. +#[allow(unused)] +mod function; use std::sync::Arc; @@ -23,14 +28,13 @@ use common_query::prelude::{create_udf, make_scalar_function, Volatility}; use common_query::Output; use common_recordbatch::error::Result as RecordResult; use common_recordbatch::{util, RecordBatch}; -use datafusion::field_util::{FieldExt, SchemaExt}; -use datafusion::logical_plan::LogicalPlanBuilder; -use datatypes::arrow::array::UInt32Array; +use datafusion::datasource::DefaultTableSource; +use datafusion_expr::logical_plan::builder::LogicalPlanBuilder; use datatypes::for_all_primitive_types; use datatypes::prelude::*; use datatypes::schema::{ColumnSchema, Schema}; -use datatypes::types::{OrdPrimitive, PrimitiveElement}; -use datatypes::vectors::{PrimitiveVector, UInt32Vector}; +use datatypes::types::{OrdPrimitive, WrapperType}; +use datatypes::vectors::UInt32Vector; use num::NumCast; use query::error::Result; use query::plan::LogicalPlan; @@ -66,12 +70,16 @@ async fn test_datafusion_query_engine() -> Result<()> { let limit = 10; let table_provider = Arc::new(DfTableProviderAdapter::new(table.clone())); let plan = LogicalPlan::DfPlan( - LogicalPlanBuilder::scan("numbers", table_provider, None) - .unwrap() - .limit(limit) - .unwrap() - .build() - .unwrap(), + LogicalPlanBuilder::scan( + "numbers", + Arc::new(DefaultTableSource { table_provider }), + None, + ) + .unwrap() + .limit(0, Some(limit)) + .unwrap() + .build() + .unwrap(), ); let output = engine.execute(&plan).await?; @@ -84,17 +92,17 @@ async fn test_datafusion_query_engine() -> Result<()> { let numbers = util::collect(recordbatch).await.unwrap(); assert_eq!(1, numbers.len()); - assert_eq!(numbers[0].df_recordbatch.num_columns(), 1); - assert_eq!(1, numbers[0].schema.arrow_schema().fields().len()); - assert_eq!("number", numbers[0].schema.arrow_schema().field(0).name()); + assert_eq!(numbers[0].num_columns(), 1); + assert_eq!(1, numbers[0].schema.num_columns()); + assert_eq!("number", numbers[0].schema.column_schemas()[0].name); - let columns = numbers[0].df_recordbatch.columns(); - assert_eq!(1, columns.len()); - assert_eq!(columns[0].len(), limit); + let batch = &numbers[0]; + assert_eq!(1, batch.num_columns()); + assert_eq!(batch.column(0).len(), limit); let expected: Vec = (0u32..limit as u32).collect(); assert_eq!( - *columns[0].as_any().downcast_ref::().unwrap(), - UInt32Array::from_slice(&expected) + *batch.column(0), + Arc::new(UInt32Vector::from_slice(&expected)) as VectorRef ); Ok(()) @@ -148,17 +156,17 @@ async fn test_udf() -> Result<()> { let numbers = util::collect(recordbatch).await.unwrap(); assert_eq!(1, numbers.len()); - assert_eq!(numbers[0].df_recordbatch.num_columns(), 1); - assert_eq!(1, numbers[0].schema.arrow_schema().fields().len()); - assert_eq!("p", numbers[0].schema.arrow_schema().field(0).name()); + assert_eq!(numbers[0].num_columns(), 1); + assert_eq!(1, numbers[0].schema.num_columns()); + assert_eq!("p", numbers[0].schema.column_schemas()[0].name); - let columns = numbers[0].df_recordbatch.columns(); - assert_eq!(1, columns.len()); - assert_eq!(columns[0].len(), 10); + let batch = &numbers[0]; + assert_eq!(1, batch.num_columns()); + assert_eq!(batch.column(0).len(), 10); let expected: Vec = vec![1, 1, 4, 27, 256, 3125, 46656, 823543, 16777216, 387420489]; assert_eq!( - *columns[0].as_any().downcast_ref::().unwrap(), - UInt32Array::from_slice(&expected) + *batch.column(0), + Arc::new(UInt32Vector::from_slice(&expected)) as VectorRef ); Ok(()) @@ -182,7 +190,7 @@ fn create_query_engine() -> Arc { column_schemas.push(column_schema); let numbers = (1..=100).map(|_| rng.gen::<$T>()).collect::>(); - let column: VectorRef = Arc::new(PrimitiveVector::<$T>::from_vec(numbers.to_vec())); + let column: VectorRef = Arc::new(<$T as Scalar>::VectorType::from_vec(numbers.to_vec())); columns.push(column); )* } @@ -212,7 +220,7 @@ fn create_query_engine() -> Arc { column_schemas.push(column_schema); let numbers = (1..=99).map(|_| rng.gen::<$T>()).collect::>(); - let column: VectorRef = Arc::new(PrimitiveVector::<$T>::from_vec(numbers.to_vec())); + let column: VectorRef = Arc::new(<$T as Scalar>::VectorType::from_vec(numbers.to_vec())); columns.push(column); )* } @@ -236,37 +244,6 @@ fn create_query_engine() -> Arc { QueryEngineFactory::new(catalog_list).query_engine() } -async fn get_numbers_from_table<'s, T>( - column_name: &'s str, - table_name: &'s str, - engine: Arc, -) -> Vec> -where - T: PrimitiveElement, - for<'a> T: Scalar = T>, -{ - let sql = format!("SELECT {} FROM {}", column_name, table_name); - let plan = engine - .sql_to_plan(&sql, Arc::new(QueryContext::new())) - .unwrap(); - - let output = engine.execute(&plan).await.unwrap(); - let recordbatch_stream = match output { - Output::Stream(batch) => batch, - _ => unreachable!(), - }; - let numbers = util::collect(recordbatch_stream).await.unwrap(); - - let columns = numbers[0].df_recordbatch.columns(); - let column = VectorHelper::try_into_vector(&columns[0]).unwrap(); - let column: &::VectorType = unsafe { VectorHelper::static_cast(&column) }; - column - .iter_data() - .flatten() - .map(|x| OrdPrimitive::(x)) - .collect::>>() -} - #[tokio::test] async fn test_median_aggregator() -> Result<()> { common_telemetry::init_default_ut_logging(); @@ -294,25 +271,17 @@ async fn test_median_success( engine: Arc, ) -> Result<()> where - T: PrimitiveElement, - for<'a> T: Scalar = T>, + T: WrapperType, + T::Native: NumCast, { let result = execute_median(column_name, table_name, engine.clone()) .await .unwrap(); - assert_eq!(1, result.len()); - assert_eq!(result[0].df_recordbatch.num_columns(), 1); - assert_eq!(1, result[0].schema.arrow_schema().fields().len()); - assert_eq!("median", result[0].schema.arrow_schema().field(0).name()); - - let columns = result[0].df_recordbatch.columns(); - assert_eq!(1, columns.len()); - assert_eq!(columns[0].len(), 1); - let v = VectorHelper::try_into_vector(&columns[0]).unwrap(); - assert_eq!(1, v.len()); - let median = v.get(0); - - let mut numbers = get_numbers_from_table::(column_name, table_name, engine.clone()).await; + let median = function::get_value_from_batches("median", result); + + let numbers = + function::get_numbers_from_table::(column_name, table_name, engine.clone()).await; + let mut numbers: Vec<_> = numbers.into_iter().map(|v| OrdPrimitive(v)).collect(); numbers.sort(); let len = numbers.len(); let expected_median: Value = if len % 2 == 1 { @@ -320,7 +289,7 @@ where } else { let a: f64 = NumCast::from(numbers[len / 2 - 1].as_primitive()).unwrap(); let b: f64 = NumCast::from(numbers[len / 2].as_primitive()).unwrap(); - OrdPrimitive::(NumCast::from(a / 2.0 + b / 2.0).unwrap()) + OrdPrimitive::(T::from_native(NumCast::from(a / 2.0 + b / 2.0).unwrap())) } .into(); assert_eq!(expected_median, median); diff --git a/src/query/tests/scipy_stats_norm_cdf_test.rs b/src/query/tests/scipy_stats_norm_cdf_test.rs index 815501a314cb..dee8f5c87ee3 100644 --- a/src/query/tests/scipy_stats_norm_cdf_test.rs +++ b/src/query/tests/scipy_stats_norm_cdf_test.rs @@ -18,11 +18,8 @@ mod function; use common_query::Output; use common_recordbatch::error::Result as RecordResult; use common_recordbatch::{util, RecordBatch}; -use datafusion::field_util::{FieldExt, SchemaExt}; use datatypes::for_all_primitive_types; -use datatypes::prelude::*; -use datatypes::types::PrimitiveElement; -use function::{create_query_engine, get_numbers_from_table}; +use datatypes::types::WrapperType; use num_traits::AsPrimitive; use query::error::Result; use query::QueryEngine; @@ -33,7 +30,7 @@ use statrs::statistics::Statistics; #[tokio::test] async fn test_scipy_stats_norm_cdf_aggregator() -> Result<()> { common_telemetry::init_default_ut_logging(); - let engine = create_query_engine(); + let engine = function::create_query_engine(); macro_rules! test_scipy_stats_norm_cdf { ([], $( { $T:ty } ),*) => { @@ -53,28 +50,15 @@ async fn test_scipy_stats_norm_cdf_success( engine: Arc, ) -> Result<()> where - T: PrimitiveElement + AsPrimitive, - for<'a> T: Scalar = T>, + T: WrapperType + AsPrimitive, { let result = execute_scipy_stats_norm_cdf(column_name, table_name, engine.clone()) .await .unwrap(); - assert_eq!(1, result.len()); - assert_eq!(result[0].df_recordbatch.num_columns(), 1); - assert_eq!(1, result[0].schema.arrow_schema().fields().len()); - assert_eq!( - "scipy_stats_norm_cdf", - result[0].schema.arrow_schema().field(0).name() - ); - - let columns = result[0].df_recordbatch.columns(); - assert_eq!(1, columns.len()); - assert_eq!(columns[0].len(), 1); - let v = VectorHelper::try_into_vector(&columns[0]).unwrap(); - assert_eq!(1, v.len()); - let value = v.get(0); + let value = function::get_value_from_batches("scipy_stats_norm_cdf", result); - let numbers = get_numbers_from_table::(column_name, table_name, engine.clone()).await; + let numbers = + function::get_numbers_from_table::(column_name, table_name, engine.clone()).await; let expected_value = numbers.iter().map(|&n| n.as_()).collect::>(); let mean = expected_value.clone().mean(); let stddev = expected_value.std_dev(); diff --git a/src/query/tests/scipy_stats_norm_pdf.rs b/src/query/tests/scipy_stats_norm_pdf.rs index dd5e0fc7fc5b..03e4cf129220 100644 --- a/src/query/tests/scipy_stats_norm_pdf.rs +++ b/src/query/tests/scipy_stats_norm_pdf.rs @@ -18,11 +18,8 @@ mod function; use common_query::Output; use common_recordbatch::error::Result as RecordResult; use common_recordbatch::{util, RecordBatch}; -use datafusion::field_util::{FieldExt, SchemaExt}; use datatypes::for_all_primitive_types; -use datatypes::prelude::*; -use datatypes::types::PrimitiveElement; -use function::{create_query_engine, get_numbers_from_table}; +use datatypes::types::WrapperType; use num_traits::AsPrimitive; use query::error::Result; use query::QueryEngine; @@ -33,7 +30,7 @@ use statrs::statistics::Statistics; #[tokio::test] async fn test_scipy_stats_norm_pdf_aggregator() -> Result<()> { common_telemetry::init_default_ut_logging(); - let engine = create_query_engine(); + let engine = function::create_query_engine(); macro_rules! test_scipy_stats_norm_pdf { ([], $( { $T:ty } ),*) => { @@ -53,28 +50,15 @@ async fn test_scipy_stats_norm_pdf_success( engine: Arc, ) -> Result<()> where - T: PrimitiveElement + AsPrimitive, - for<'a> T: Scalar = T>, + T: WrapperType + AsPrimitive, { let result = execute_scipy_stats_norm_pdf(column_name, table_name, engine.clone()) .await .unwrap(); - assert_eq!(1, result.len()); - assert_eq!(result[0].df_recordbatch.num_columns(), 1); - assert_eq!(1, result[0].schema.arrow_schema().fields().len()); - assert_eq!( - "scipy_stats_norm_pdf", - result[0].schema.arrow_schema().field(0).name() - ); - - let columns = result[0].df_recordbatch.columns(); - assert_eq!(1, columns.len()); - assert_eq!(columns[0].len(), 1); - let v = VectorHelper::try_into_vector(&columns[0]).unwrap(); - assert_eq!(1, v.len()); - let value = v.get(0); + let value = function::get_value_from_batches("scipy_stats_norm_pdf", result); - let numbers = get_numbers_from_table::(column_name, table_name, engine.clone()).await; + let numbers = + function::get_numbers_from_table::(column_name, table_name, engine.clone()).await; let expected_value = numbers.iter().map(|&n| n.as_()).collect::>(); let mean = expected_value.clone().mean(); let stddev = expected_value.std_dev();