Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ARROW-6668: [Rust] [DataFusion] Implement CAST expression #5477

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions rust/datafusion/src/execution/physical_plan/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use arrow::array::{
Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder,
Int8Builder, UInt16Builder, UInt32Builder, UInt64Builder, UInt8Builder,
};
use arrow::compute::kernels::cast::cast;
use arrow::datatypes::{DataType, Schema};
use arrow::record_batch::RecordBatch;

Expand Down Expand Up @@ -196,6 +197,61 @@ pub fn sum(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn AggregateExpr> {
Arc::new(Sum::new(expr))
}

/// CAST expression casts an expression to a specific data type
pub struct CastExpr {
/// The expression to cast
expr: Arc<dyn PhysicalExpr>,
/// The data type to cast to
cast_type: DataType,
}

/// Determine if a DataType is numeric or not
fn is_numeric(dt: &DataType) -> bool {
match dt {
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => true,
DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => true,
DataType::Float16 | DataType::Float32 | DataType::Float64 => true,
_ => false,
}
}

impl CastExpr {
/// Create a CAST expression
pub fn try_new(
expr: Arc<dyn PhysicalExpr>,
input_schema: &Schema,
cast_type: DataType,
) -> Result<Self> {
let expr_type = expr.data_type(input_schema)?;
// numbers can be cast to numbers and strings
if is_numeric(&expr_type)
&& (is_numeric(&cast_type) || cast_type == DataType::Utf8)
{
Ok(Self { expr, cast_type })
} else {
Err(ExecutionError::General(format!(
"Invalid CAST from {:?} to {:?}",
expr_type, cast_type
)))
}
}
}

impl PhysicalExpr for CastExpr {
fn name(&self) -> String {
"CAST".to_string()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: there is inconsistency in upper vs lower case for name. The literal one was lower case, this one and SUM were upper case.

}

fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
Ok(self.cast_type.clone())
}

fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
let value = self.expr.evaluate(batch)?;
Ok(cast(&value, &self.cast_type)?)
}
}

/// Represents a non-null literal value
pub struct Literal {
value: ScalarValue,
Expand Down Expand Up @@ -276,6 +332,7 @@ pub fn lit(value: ScalarValue) -> Arc<dyn PhysicalExpr> {
mod tests {
use super::*;
use crate::error::Result;
use arrow::array::BinaryArray;
use arrow::datatypes::*;

#[test]
Expand All @@ -299,6 +356,56 @@ mod tests {
Ok(())
}

#[test]
fn cast_i32_to_u32() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

let cast = CastExpr::try_new(col(0), &schema, DataType::UInt32)?;
let result = cast.evaluate(&batch)?;
assert_eq!(result.len(), 5);

let result = result
.as_any()
.downcast_ref::<UInt32Array>()
.expect("failed to downcast to UInt32Array");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm curious about how the API is designed to work here. Why don't you handle the downcast in evaluate based on data_type()? Is it just that you don't want the user to always incur the cost of the match on data_type().

assert_eq!(result.value(0), 1_u32);

Ok(())
}

#[test]
fn cast_i32_to_utf8() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

let cast = CastExpr::try_new(col(0), &schema, DataType::Utf8)?;
let result = cast.evaluate(&batch)?;
assert_eq!(result.len(), 5);

let result = result
.as_any()
.downcast_ref::<BinaryArray>()
.expect("failed to downcast to BinaryArray");
assert_eq!(result.value(0), "1".as_bytes());

Ok(())
}

#[test]
fn invalid_cast() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]);
match CastExpr::try_new(col(0), &schema, DataType::Int32) {
Err(ExecutionError::General(ref str)) => {
assert_eq!(str, "Invalid CAST from Utf8 to Int32");
Ok(())
}
_ => panic!(),
}
}

#[test]
fn sum_contract() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
Expand Down