fix: Fix all warnings in tests
evenyag committed Dec 13, 2022
1 parent 6c1a558 commit 38b8921
Showing 13 changed files with 180 additions and 269 deletions.
2 changes: 1 addition & 1 deletion src/common/recordbatch/src/lib.rs
@@ -66,7 +66,7 @@ impl Stream for EmptyRecordBatchStream {
}
}

#[derive(Debug)]
#[derive(Debug, PartialEq)]
pub struct RecordBatches {
schema: SchemaRef,
batches: Vec<RecordBatch>,
8 changes: 8 additions & 0 deletions src/common/recordbatch/src/recordbatch.rs
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use datatypes::arrow::util::pretty;
use datatypes::schema::SchemaRef;
use datatypes::value::Value;
use datatypes::vectors::{Helper, VectorRef};
@@ -98,6 +99,13 @@ impl RecordBatch {
pub fn rows(&self) -> RecordBatchRowIterator<'_> {
RecordBatchRowIterator::new(self)
}

pub fn pretty_print(&self) -> Result<String> {
let df_batch = self.df_record_batch.clone();
let result = pretty::pretty_format_batches(&[df_batch]).context(error::FormatSnafu)?;

Ok(result.to_string())
}
}

impl Serialize for RecordBatch {
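The new pretty_print method renders a batch as an ASCII table through arrow's pretty_format_batches, which keeps test assertions and debug output readable. A minimal sketch of the intended use, assuming a RecordBatch named batch built by test setup (the "number" column name is illustrative, not from this commit):

    // Sketch only: `batch` comes from test setup; the "number" column is hypothetical.
    let rendered = batch.pretty_print().unwrap();
    // pretty_format_batches emits a bordered table, so the header row can be matched directly.
    assert!(rendered.contains("| number |"));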
18 changes: 16 additions & 2 deletions src/common/recordbatch/src/util.rs
@@ -15,13 +15,20 @@
use futures::TryStreamExt;

use crate::error::Result;
use crate::{RecordBatch, SendableRecordBatchStream};
use crate::{RecordBatch, RecordBatches, SendableRecordBatchStream};

/// Collect all the items from the stream into a vector of [`RecordBatch`].
pub async fn collect(stream: SendableRecordBatchStream) -> Result<Vec<RecordBatch>> {
stream.try_collect::<Vec<_>>().await
}

/// Collect all the items from the stream into [RecordBatches].
pub async fn collect_batches(stream: SendableRecordBatchStream) -> Result<RecordBatches> {
let schema = stream.schema();
let batches = stream.try_collect::<Vec<_>>().await?;
RecordBatches::try_new(schema, batches)
}

#[cfg(test)]
mod tests {
use std::mem;
@@ -90,7 +97,14 @@ mod tests {
};
let batches = collect(Box::pin(stream)).await.unwrap();
assert_eq!(1, batches.len());

assert_eq!(batch, batches[0]);

let stream = MockRecordBatchStream {
schema: schema.clone(),
batch: Some(batch.clone()),
};
let batches = collect_batches(Box::pin(stream)).await.unwrap();
let expect_batches = RecordBatches::try_new(schema.clone(), vec![batch]).unwrap();
assert_eq!(expect_batches, batches);
}
}
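Together with the PartialEq derive on RecordBatches, collect_batches lets a test compare a whole streamed result against an expected value in a single assertion, as the new unit test above does. A hedged sketch of the pattern, assuming the stream, its schema, and an expected batch are prepared by the test:

    // Sketch only: `stream`, `schema`, and `expected_batch` come from test setup.
    let batches = collect_batches(Box::pin(stream)).await.unwrap();
    let expected = RecordBatches::try_new(schema, vec![expected_batch]).unwrap();
    assert_eq!(expected, batches);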
34 changes: 12 additions & 22 deletions src/query/tests/argmax_test.rs
@@ -12,24 +12,24 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;
mod function;

use std::sync::Arc;

use common_query::Output;
use common_recordbatch::error::Result as RecordResult;
use common_recordbatch::{util, RecordBatch};
use datafusion::field_util::{FieldExt, SchemaExt};
use datatypes::for_all_primitive_types;
use datatypes::prelude::*;
use datatypes::types::PrimitiveElement;
use function::{create_query_engine, get_numbers_from_table};
use datatypes::types::WrapperType;
use query::error::Result;
use query::QueryEngine;
use session::context::QueryContext;

#[tokio::test]
async fn test_argmax_aggregator() -> Result<()> {
common_telemetry::init_default_ut_logging();
let engine = create_query_engine();
let engine = function::create_query_engine();

macro_rules! test_argmax {
([], $( { $T:ty } ),*) => {
@@ -49,33 +49,23 @@ async fn test_argmax_success<T>(
engine: Arc<dyn QueryEngine>,
) -> Result<()>
where
T: PrimitiveElement + PartialOrd,
for<'a> T: Scalar<RefType<'a> = T>,
T: WrapperType + PartialOrd,
{
let result = execute_argmax(column_name, table_name, engine.clone())
.await
.unwrap();
assert_eq!(1, result.len());
assert_eq!(result[0].df_recordbatch.num_columns(), 1);
assert_eq!(1, result[0].schema.arrow_schema().fields().len());
assert_eq!("argmax", result[0].schema.arrow_schema().field(0).name());

let columns = result[0].df_recordbatch.columns();
assert_eq!(1, columns.len());
assert_eq!(columns[0].len(), 1);
let v = VectorHelper::try_into_vector(&columns[0]).unwrap();
assert_eq!(1, v.len());
let value = v.get(0);
let value = function::get_value_from_batches("argmax", result);

let numbers = get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
let numbers =
function::get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
let expected_value = match numbers.len() {
0 => 0_u64,
_ => {
let mut index = 0;
let mut max = numbers[0].into();
let mut max = numbers[0];
for (i, &number) in numbers.iter().enumerate() {
if max < number.into() {
max = number.into();
if max < number {
max = number;
index = i;
}
}
33 changes: 11 additions & 22 deletions src/query/tests/argmin_test.rs
@@ -12,25 +12,24 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;
mod function;

use std::sync::Arc;

use common_query::Output;
use common_recordbatch::error::Result as RecordResult;
use common_recordbatch::{util, RecordBatch};
use datafusion::field_util::{FieldExt, SchemaExt};
use datatypes::for_all_primitive_types;
use datatypes::prelude::*;
use datatypes::types::PrimitiveElement;
use function::{create_query_engine, get_numbers_from_table};
use datatypes::types::WrapperType;
use query::error::Result;
use query::QueryEngine;
use session::context::QueryContext;

#[tokio::test]
async fn test_argmin_aggregator() -> Result<()> {
common_telemetry::init_default_ut_logging();
let engine = create_query_engine();
let engine = function::create_query_engine();

macro_rules! test_argmin {
([], $( { $T:ty } ),*) => {
@@ -50,33 +49,23 @@ async fn test_argmin_success<T>(
engine: Arc<dyn QueryEngine>,
) -> Result<()>
where
T: PrimitiveElement + PartialOrd,
for<'a> T: Scalar<RefType<'a> = T>,
T: WrapperType + PartialOrd,
{
let result = execute_argmin(column_name, table_name, engine.clone())
.await
.unwrap();
assert_eq!(1, result.len());
assert_eq!(result[0].df_recordbatch.num_columns(), 1);
assert_eq!(1, result[0].schema.arrow_schema().fields().len());
assert_eq!("argmin", result[0].schema.arrow_schema().field(0).name());

let columns = result[0].df_recordbatch.columns();
assert_eq!(1, columns.len());
assert_eq!(columns[0].len(), 1);
let v = VectorHelper::try_into_vector(&columns[0]).unwrap();
assert_eq!(1, v.len());
let value = v.get(0);
let value = function::get_value_from_batches("argmin", result);

let numbers = get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
let numbers =
function::get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
let expected_value = match numbers.len() {
0 => 0_u32,
_ => {
let mut index = 0;
let mut min = numbers[0].into();
let mut min = numbers[0];
for (i, &number) in numbers.iter().enumerate() {
if min > number.into() {
min = number.into();
if min > number {
min = number;
index = i;
}
}
22 changes: 19 additions & 3 deletions src/query/tests/function.rs
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

// FIXME(yingwen): Consider move all tests under query/tests to query/src so we could reuse
// more codes.
use std::sync::Arc;

use catalog::local::{MemoryCatalogManager, MemoryCatalogProvider, MemorySchemaProvider};
@@ -23,7 +25,7 @@ use datatypes::for_all_primitive_types;
use datatypes::prelude::*;
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::types::WrapperType;
use datatypes::vectors::PrimitiveVector;
use datatypes::vectors::Helper;
use query::query_engine::QueryEngineFactory;
use query::QueryEngine;
use rand::Rng;
@@ -47,7 +49,7 @@ pub fn create_query_engine() -> Arc<dyn QueryEngine> {
column_schemas.push(column_schema);

let numbers = (1..=10).map(|_| rng.gen::<$T>()).collect::<Vec<$T>>();
let column: VectorRef = Arc::new(PrimitiveVector::<$T>::from_vec(numbers.to_vec()));
let column: VectorRef = Arc::new(<$T as Scalar>::VectorType::from_vec(numbers.to_vec()));
columns.push(column);
)*
}
@@ -92,6 +94,20 @@
let numbers = util::collect(recordbatch_stream).await.unwrap();

let column = numbers[0].column(0);
let column: &<T as Scalar>::VectorType = unsafe { VectorHelper::static_cast(column) };
let column: &<T as Scalar>::VectorType = unsafe { Helper::static_cast(column) };
column.iter_data().flatten().collect::<Vec<T>>()
}

pub fn get_value_from_batches(column_name: &str, batches: Vec<RecordBatch>) -> Value {
assert_eq!(1, batches.len());
assert_eq!(batches[0].num_columns(), 1);
assert_eq!(1, batches[0].schema.num_columns());
assert_eq!(column_name, batches[0].schema.column_schemas()[0].name);

let batch = &batches[0];
assert_eq!(1, batch.num_columns());
assert_eq!(batch.column(0).len(), 1);
let v = batch.column(0);
assert_eq!(1, v.len());
v.get(0)
}
27 changes: 8 additions & 19 deletions src/query/tests/mean_test.rs
@@ -12,19 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;
mod function;

use std::sync::Arc;

use common_query::Output;
use common_recordbatch::error::Result as RecordResult;
use common_recordbatch::{util, RecordBatch};
use datafusion::field_util::{FieldExt, SchemaExt};
use datatypes::for_all_primitive_types;
use datatypes::prelude::*;
use datatypes::types::PrimitiveElement;
use datatypes::types::WrapperType;
use datatypes::value::OrderedFloat;
use format_num::NumberFormat;
use function::{create_query_engine, get_numbers_from_table};
use num_traits::AsPrimitive;
use query::error::Result;
use query::QueryEngine;
@@ -33,7 +32,7 @@ use session::context::QueryContext;
#[tokio::test]
async fn test_mean_aggregator() -> Result<()> {
common_telemetry::init_default_ut_logging();
let engine = create_query_engine();
let engine = function::create_query_engine();

macro_rules! test_mean {
([], $( { $T:ty } ),*) => {
@@ -53,25 +52,15 @@ async fn test_mean_success<T>(
engine: Arc<dyn QueryEngine>,
) -> Result<()>
where
T: PrimitiveElement + AsPrimitive<f64>,
for<'a> T: Scalar<RefType<'a> = T>,
T: WrapperType + AsPrimitive<f64>,
{
let result = execute_mean(column_name, table_name, engine.clone())
.await
.unwrap();
assert_eq!(1, result.len());
assert_eq!(result[0].df_recordbatch.num_columns(), 1);
assert_eq!(1, result[0].schema.arrow_schema().fields().len());
assert_eq!("mean", result[0].schema.arrow_schema().field(0).name());

let columns = result[0].df_recordbatch.columns();
assert_eq!(1, columns.len());
assert_eq!(columns[0].len(), 1);
let v = VectorHelper::try_into_vector(&columns[0]).unwrap();
assert_eq!(1, v.len());
let value = v.get(0);
let value = function::get_value_from_batches("mean", result);

let numbers = get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
let numbers =
function::get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
let expected_value = numbers.iter().map(|&n| n.as_()).collect::<Vec<f64>>();

let expected_value = inc_stats::mean(expected_value.iter().cloned()).unwrap();