diff --git a/src/common/function/src/scalars/numpy/clip.rs b/src/common/function/src/scalars/numpy/clip.rs index 8b3ecb7616db..888a080f3fcf 100644 --- a/src/common/function/src/scalars/numpy/clip.rs +++ b/src/common/function/src/scalars/numpy/clip.rs @@ -17,10 +17,11 @@ use std::sync::Arc; use common_query::error::Result; use common_query::prelude::{Signature, Volatility}; -use datatypes::data_type::{ConcreteDataType, DataType}; +use datatypes::arrow::compute; +use datatypes::arrow::datatypes::ArrowPrimitiveType; +use datatypes::data_type::ConcreteDataType; use datatypes::prelude::*; -use datatypes::with_match_primitive_type_id; -use num_traits::AsPrimitive; +use datatypes::vectors::PrimitiveVector; use paste::paste; use crate::scalars::expression::{scalar_binary_op, EvalContext}; @@ -34,40 +35,32 @@ macro_rules! define_eval { ($O: ident) => { paste! { fn [](columns: &[VectorRef]) -> Result { - with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| { - with_match_primitive_type_id!(columns[1].data_type().logical_type_id(), |$T| { - with_match_primitive_type_id!(columns[2].data_type().logical_type_id(), |$R| { - // clip(a, min, max) is equals to min(max(a, min), max) - let col: VectorRef = Arc::new(scalar_binary_op::< - <$S as LogicalPrimitiveType>::Wrapper, - <$T as LogicalPrimitiveType>::Wrapper, - $O, - _, - >( - &columns[0], - &columns[1], - scalar_max, - &mut EvalContext::default(), - )?); - let col = scalar_binary_op::<$O, <$R as LogicalPrimitiveType>::Wrapper, $O, _>( - &col, - &columns[2], - scalar_min, - &mut EvalContext::default(), - )?; - Ok(Arc::new(col)) - }, { - unreachable!() - }) - }, { - unreachable!() - }) - }, { - unreachable!() - }) + fn cast_vector(input: &VectorRef) -> VectorRef { + Arc::new(PrimitiveVector::<<$O as WrapperType>::LogicalType>::try_from_arrow_array( + compute::cast(&input.to_arrow_array(), &<<<$O as WrapperType>::LogicalType as LogicalPrimitiveType>::ArrowPrimitive as ArrowPrimitiveType>::DATA_TYPE).unwrap() + ).unwrap()) as _ + } + let operator_1 = cast_vector(&columns[0]); + let operator_2 = cast_vector(&columns[1]); + let operator_3 = cast_vector(&columns[2]); + + // clip(a, min, max) is equals to min(max(a, min), max) + let col: VectorRef = Arc::new(scalar_binary_op::<$O, $O, $O, _>( + &operator_1, + &operator_2, + scalar_max, + &mut EvalContext::default(), + )?); + let col = scalar_binary_op::<$O, $O, $O, _>( + &col, + &operator_3, + scalar_min, + &mut EvalContext::default(), + )?; + Ok(Arc::new(col)) } } - } + }; } define_eval!(i64); @@ -123,27 +116,23 @@ pub fn max(input: T, max: T) -> T { } #[inline] -fn scalar_min(left: Option, right: Option, _ctx: &mut EvalContext) -> Option +fn scalar_min(left: Option, right: Option, _ctx: &mut EvalContext) -> Option where - S: AsPrimitive, - T: AsPrimitive, O: Scalar + Copy + PartialOrd, { match (left, right) { - (Some(left), Some(right)) => Some(min(left.as_(), right.as_())), + (Some(left), Some(right)) => Some(min(left, right)), _ => None, } } #[inline] -fn scalar_max(left: Option, right: Option, _ctx: &mut EvalContext) -> Option +fn scalar_max(left: Option, right: Option, _ctx: &mut EvalContext) -> Option where - S: AsPrimitive, - T: AsPrimitive, O: Scalar + Copy + PartialOrd, { match (left, right) { - (Some(left), Some(right)) => Some(max(left.as_(), right.as_())), + (Some(left), Some(right)) => Some(max(left, right)), _ => None, } }