Skip to content

Commit

Permalink
fix: pre-cast to avoid tremendous match arms (#734)
Browse files Browse the repository at this point in the history
Signed-off-by: Ruihang Xia <[email protected]>

Signed-off-by: Ruihang Xia <[email protected]>
  • Loading branch information
waynexia authored Dec 9, 2022
1 parent 42fdc72 commit 95b2d86
Showing 1 changed file with 32 additions and 43 deletions.
75 changes: 32 additions & 43 deletions src/common/function/src/scalars/numpy/clip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ use std::sync::Arc;

use common_query::error::Result;
use common_query::prelude::{Signature, Volatility};
use datatypes::data_type::{ConcreteDataType, DataType};
use datatypes::arrow::compute;
use datatypes::arrow::datatypes::ArrowPrimitiveType;
use datatypes::data_type::ConcreteDataType;
use datatypes::prelude::*;
use datatypes::with_match_primitive_type_id;
use num_traits::AsPrimitive;
use datatypes::vectors::PrimitiveVector;
use paste::paste;

use crate::scalars::expression::{scalar_binary_op, EvalContext};
Expand All @@ -34,40 +35,32 @@ macro_rules! define_eval {
($O: ident) => {
paste! {
fn [<eval_ $O>](columns: &[VectorRef]) -> Result<VectorRef> {
with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
with_match_primitive_type_id!(columns[1].data_type().logical_type_id(), |$T| {
with_match_primitive_type_id!(columns[2].data_type().logical_type_id(), |$R| {
// clip(a, min, max) is equals to min(max(a, min), max)
let col: VectorRef = Arc::new(scalar_binary_op::<
<$S as LogicalPrimitiveType>::Wrapper,
<$T as LogicalPrimitiveType>::Wrapper,
$O,
_,
>(
&columns[0],
&columns[1],
scalar_max,
&mut EvalContext::default(),
)?);
let col = scalar_binary_op::<$O, <$R as LogicalPrimitiveType>::Wrapper, $O, _>(
&col,
&columns[2],
scalar_min,
&mut EvalContext::default(),
)?;
Ok(Arc::new(col))
}, {
unreachable!()
})
}, {
unreachable!()
})
}, {
unreachable!()
})
fn cast_vector(input: &VectorRef) -> VectorRef {
Arc::new(PrimitiveVector::<<$O as WrapperType>::LogicalType>::try_from_arrow_array(
compute::cast(&input.to_arrow_array(), &<<<$O as WrapperType>::LogicalType as LogicalPrimitiveType>::ArrowPrimitive as ArrowPrimitiveType>::DATA_TYPE).unwrap()
).unwrap()) as _
}
let operator_1 = cast_vector(&columns[0]);
let operator_2 = cast_vector(&columns[1]);
let operator_3 = cast_vector(&columns[2]);

// clip(a, min, max) is equals to min(max(a, min), max)
let col: VectorRef = Arc::new(scalar_binary_op::<$O, $O, $O, _>(
&operator_1,
&operator_2,
scalar_max,
&mut EvalContext::default(),
)?);
let col = scalar_binary_op::<$O, $O, $O, _>(
&col,
&operator_3,
scalar_min,
&mut EvalContext::default(),
)?;
Ok(Arc::new(col))
}
}
}
};
}

define_eval!(i64);
Expand Down Expand Up @@ -123,27 +116,23 @@ pub fn max<T: PartialOrd>(input: T, max: T) -> T {
}

#[inline]
fn scalar_min<S, T, O>(left: Option<S>, right: Option<T>, _ctx: &mut EvalContext) -> Option<O>
fn scalar_min<O>(left: Option<O>, right: Option<O>, _ctx: &mut EvalContext) -> Option<O>
where
S: AsPrimitive<O>,
T: AsPrimitive<O>,
O: Scalar + Copy + PartialOrd,
{
match (left, right) {
(Some(left), Some(right)) => Some(min(left.as_(), right.as_())),
(Some(left), Some(right)) => Some(min(left, right)),
_ => None,
}
}

#[inline]
fn scalar_max<S, T, O>(left: Option<S>, right: Option<T>, _ctx: &mut EvalContext) -> Option<O>
fn scalar_max<O>(left: Option<O>, right: Option<O>, _ctx: &mut EvalContext) -> Option<O>
where
S: AsPrimitive<O>,
T: AsPrimitive<O>,
O: Scalar + Copy + PartialOrd,
{
match (left, right) {
(Some(left), Some(right)) => Some(max(left.as_(), right.as_())),
(Some(left), Some(right)) => Some(max(left, right)),
_ => None,
}
}
Expand Down

0 comments on commit 95b2d86

Please sign in to comment.