diff --git a/.gitmodules b/.gitmodules index ec5d6208b8dd..5d0594c0c75f 100644 --- a/.gitmodules +++ b/.gitmodules @@ -3,4 +3,4 @@ url = https://github.com/apache/parquet-testing.git [submodule "testing"] path = testing - url = https://github.com/apache/arrow-testing + url = https://github.com/apache/arrow-testing \ No newline at end of file diff --git a/ballista-examples/Cargo.toml b/ballista-examples/Cargo.toml index 338f69994bfd..09923d54df82 100644 --- a/ballista-examples/Cargo.toml +++ b/ballista-examples/Cargo.toml @@ -29,7 +29,7 @@ publish = false rust-version = "1.57" [dependencies] -datafusion = { path = "../datafusion" } +datafusion = { path = "../datafusion" } ballista = { path = "../ballista/rust/client", version = "0.6.0"} prost = "0.9" tonic = "0.6" diff --git a/ballista/rust/client/Cargo.toml b/ballista/rust/client/Cargo.toml index 7736e949d29f..7e3006bfcf2a 100644 --- a/ballista/rust/client/Cargo.toml +++ b/ballista/rust/client/Cargo.toml @@ -27,14 +27,14 @@ edition = "2021" rust-version = "1.57" [dependencies] -ballista-core = { path = "../core", version = "0.6.0" } -ballista-executor = { path = "../executor", version = "0.6.0", optional = true } -ballista-scheduler = { path = "../scheduler", version = "0.6.0", optional = true } +ballista-core = { path = "../core"} +ballista-executor = { path = "../executor", optional = true } +ballista-scheduler = { path = "../scheduler", optional = true } futures = "0.3" log = "0.4" tokio = "1.0" -datafusion = { path = "../../../datafusion", version = "6.0.0" } +datafusion = { path = "../../../datafusion" } [features] default = [] diff --git a/ballista/rust/core/Cargo.toml b/ballista/rust/core/Cargo.toml index bbf8e274c5cd..15cb530130dc 100644 --- a/ballista/rust/core/Cargo.toml +++ b/ballista/rust/core/Cargo.toml @@ -42,10 +42,9 @@ tokio = "1.0" tonic = "0.6" uuid = { version = "0.8", features = ["v4"] } chrono = { version = "0.4", default-features = false } - arrow-flight = { version = "7.0.0" } -datafusion = { path = "../../../datafusion", version = "6.0.0" } +datafusion = { path = "../../../datafusion" } [dev-dependencies] tempfile = "3" diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index cc6e00aa939f..699ebbf43262 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -80,6 +80,11 @@ message LogicalExprNode { // window expressions WindowExprNode window_expr = 18; + + //ArgoEngineAggregateUDF expressions + AggregateUDFExprNode aggregate_udf_expr = 19; + + ScalarUDFProtoExprNode scalar_udf_proto_expr = 20; } } @@ -517,6 +522,12 @@ message PhysicalExprNode { // window expressions PhysicalWindowExprNode window_expr = 15; + + // argo engine add. + PhysicalAggregateUDFExprNode aggregate_udf_expr = 16; + + PhysicalScalarUDFProtoExprNode scalar_udf_proto_expr = 17; + // argo engine add end. } } @@ -525,6 +536,19 @@ message PhysicalAggregateExprNode { PhysicalExprNode expr = 2; } +// argo engine add. +message PhysicalAggregateUDFExprNode { + string fun_name = 1; + repeated PhysicalExprNode expr = 2; +} + +message PhysicalScalarUDFProtoExprNode { + string fun_name = 1; + repeated PhysicalExprNode expr = 2; + ArrowType return_type = 3; +} +// argo engine add end. 
+ message PhysicalWindowExprNode { oneof window_function { AggregateFunction aggr_function = 1; @@ -976,6 +1000,19 @@ service SchedulerGrpc { rpc GetJobStatus (GetJobStatusParams) returns (GetJobStatusResult) {} } +/////////////////////////////////////////////////////////////////////////////////////////////////// +// ArgoEngine add. +/////////////////////////////////////////////////////////////////////////////////////////////////// +message AggregateUDFExprNode { + string fun_name = 1; + repeated LogicalExprNode args = 2; +} + +message ScalarUDFProtoExprNode { + string fun_name = 1; + repeated LogicalExprNode args = 2; +} + /////////////////////////////////////////////////////////////////////////////////////////////////// // Arrow Data Types /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1085,9 +1122,22 @@ message ScalarValue{ ScalarType null_list_value = 18; PrimitiveScalarType null_value = 19; + Decimal128 decimal128_value = 20; + int64 date_64_value = 21; + int64 time_second_value = 22; + int64 time_millisecond_value = 23; + int32 interval_yearmonth_value = 24; + int64 interval_daytime_value = 25; + } } +message Decimal128{ + bytes value = 1; + int64 p = 2; + int64 s = 3; +} + // Contains all valid datafusion scalar type except for // List enum PrimitiveScalarType{ @@ -1109,6 +1159,13 @@ enum PrimitiveScalarType{ TIME_MICROSECOND = 14; TIME_NANOSECOND = 15; NULL = 16; + DECIMAL128 = 17; + DATE64 = 20; + TIME_SECOND = 21; + TIME_MILLISECOND = 22; + INTERVAL_YEARMONTH = 23; + INTERVAL_DAYTIME = 24; + } message ScalarType{ diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index dfac547d7bb3..d8c6f5ee6d82 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -18,7 +18,9 @@ //! Serde code to convert from protocol buffers to Rust data structures. use crate::error::BallistaError; -use crate::serde::{from_proto_binary_op, proto_error, protobuf, str_to_byte}; +use crate::serde::{ + from_proto_binary_op, proto_error, protobuf, str_to_byte, vec_to_array, +}; use crate::{convert_box_required, convert_required}; use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use datafusion::datasource::file_format::avro::AvroFormat; @@ -540,12 +542,55 @@ fn typechecked_scalar_value_conversion( "Untyped scalar null is not a valid scalar value", )) } + + // argo engine add. + PrimitiveScalarType::Decimal128 => { + ScalarValue::Decimal128(None, 0, 0) + } + PrimitiveScalarType::Date64 => ScalarValue::Date64(None), + PrimitiveScalarType::TimeSecond => { + ScalarValue::TimestampSecond(None, None) + } + PrimitiveScalarType::TimeMillisecond => { + ScalarValue::TimestampMillisecond(None, None) + } + PrimitiveScalarType::IntervalYearmonth => { + ScalarValue::IntervalYearMonth(None) + } + PrimitiveScalarType::IntervalDaytime => { + ScalarValue::IntervalDayTime(None) + } // argo engine add end. }; scalar_value } else { return Err(proto_error("Could not convert to the proper type")); } } + + // argo engine add. 
+ (Value::Decimal128Value(val), PrimitiveScalarType::Decimal128) => { + let array = vec_to_array(val.value.clone()); + ScalarValue::Decimal128( + Some(i128::from_be_bytes(array)), + val.p as usize, + val.s as usize, + ) + } + (Value::Date64Value(v), PrimitiveScalarType::Date64) => { + ScalarValue::Date64(Some(*v)) + } + (Value::TimeSecondValue(v), PrimitiveScalarType::TimeSecond) => { + ScalarValue::TimestampSecond(Some(*v), None) + } + (Value::TimeMillisecondValue(v), PrimitiveScalarType::TimeMillisecond) => { + ScalarValue::TimestampMillisecond(Some(*v), None) + } + (Value::IntervalYearmonthValue(v), PrimitiveScalarType::IntervalYearmonth) => { + ScalarValue::IntervalYearMonth(Some(*v)) + } + (Value::IntervalDaytimeValue(v), PrimitiveScalarType::IntervalDaytime) => { + ScalarValue::IntervalDayTime(Some(*v)) + } // argo engine add end. _ => return Err(proto_error("Could not convert to the proper type")), }) } @@ -607,6 +652,31 @@ impl TryInto for &protobuf::scalar_value::Value .ok_or_else(|| proto_error("Invalid scalar type"))? .try_into()? } + + //argo engine add. + protobuf::scalar_value::Value::Decimal128Value(val) => { + let array = vec_to_array(val.value.clone()); + ScalarValue::Decimal128( + Some(i128::from_be_bytes(array)), + val.p as usize, + val.s as usize, + ) + } + protobuf::scalar_value::Value::Date64Value(v) => { + ScalarValue::Date64(Some(*v)) + } + protobuf::scalar_value::Value::TimeSecondValue(v) => { + ScalarValue::TimestampSecond(Some(*v), None) + } + protobuf::scalar_value::Value::TimeMillisecondValue(v) => { + ScalarValue::TimestampMillisecond(Some(*v), None) + } + protobuf::scalar_value::Value::IntervalYearmonthValue(v) => { + ScalarValue::IntervalYearMonth(Some(*v)) + } + protobuf::scalar_value::Value::IntervalDaytimeValue(v) => { + ScalarValue::IntervalDayTime(Some(*v)) + } //argo engine add end. }; Ok(scalar) } @@ -763,6 +833,23 @@ impl TryInto for protobuf::PrimitiveScalarType protobuf::PrimitiveScalarType::TimeNanosecond => { ScalarValue::TimestampNanosecond(None, None) } + // argo engine add. + protobuf::PrimitiveScalarType::Decimal128 => { + ScalarValue::Decimal128(None, 0, 0) + } + protobuf::PrimitiveScalarType::Date64 => ScalarValue::Date64(None), + protobuf::PrimitiveScalarType::TimeSecond => { + ScalarValue::TimestampSecond(None, None) + } + protobuf::PrimitiveScalarType::TimeMillisecond => { + ScalarValue::TimestampMillisecond(None, None) + } + protobuf::PrimitiveScalarType::IntervalYearmonth => { + ScalarValue::IntervalYearMonth(None) + } + protobuf::PrimitiveScalarType::IntervalDaytime => { + ScalarValue::IntervalDayTime(None) + } // argo engine add end. }) } } @@ -845,6 +932,31 @@ impl TryInto for &protobuf::ScalarValue { .ok_or_else(|| proto_error("Protobuf deserialization error found invalid enum variant for DatafusionScalar"))?; null_type_enum.try_into()? } + + //argo engine add. 
+ protobuf::scalar_value::Value::Decimal128Value(val) => { + let array = vec_to_array(val.value.clone()); + ScalarValue::Decimal128( + Some(i128::from_be_bytes(array)), + val.p as usize, + val.s as usize, + ) + } + protobuf::scalar_value::Value::Date64Value(v) => { + ScalarValue::Date64(Some(*v)) + } + protobuf::scalar_value::Value::TimeSecondValue(v) => { + ScalarValue::TimestampSecond(Some(*v), None) + } + protobuf::scalar_value::Value::TimeMillisecondValue(v) => { + ScalarValue::TimestampMillisecond(Some(*v), None) + } + protobuf::scalar_value::Value::IntervalYearmonthValue(v) => { + ScalarValue::IntervalYearMonth(Some(*v)) + } + protobuf::scalar_value::Value::IntervalDaytimeValue(v) => { + ScalarValue::IntervalDayTime(Some(*v)) + } //argo engine add end. }) } } @@ -966,6 +1078,59 @@ impl TryInto for &protobuf::LogicalExprNode { distinct: false, //TODO }) } + // argo engine add start + ExprType::AggregateUdfExpr(expr) => { + let gpm = global_plugin_manager("").lock().unwrap(); + let plugin_registrar = gpm.plugin_managers.get(&PluginEnum::UDF).unwrap(); + if let Some(udf_plugin_manager) = + plugin_registrar.as_any().downcast_ref::() + { + let fun = udf_plugin_manager + .aggregate_udfs + .get(expr.fun_name.as_str()) + .ok_or_else(|| { + proto_error(format!( + "can not get udaf:{} from udf_plugins!", + expr.fun_name.to_string() + )) + })?; + let fun_arc = fun.clone(); + let fun_args = &expr.args; + let args: Vec = fun_args + .iter() + .map(|e| e.try_into()) + .collect::, BallistaError>>()?; + Ok(Expr::AggregateUDF { fun: fun_arc, args }) + } else { + Err(proto_error("can not get udf plugin".to_string())) + } + } + ExprType::ScalarUdfProtoExpr(expr) => { + let gpm = global_plugin_manager("").lock().unwrap(); + let plugin_registrar = gpm.plugin_managers.get(&PluginEnum::UDF).unwrap(); + if let Some(udf_plugin_manager) = + plugin_registrar.as_any().downcast_ref::() + { + let fun = udf_plugin_manager + .scalar_udfs + .get(expr.fun_name.as_str()) + .ok_or_else(|| { + proto_error(format!( + "can not get udf:{} from udf_plugins!", + expr.fun_name.to_string() + )) + })?; + let fun_arc = fun.clone(); + let fun_args = &expr.args; + let args: Vec = fun_args + .iter() + .map(|e| e.try_into()) + .collect::, BallistaError>>()?; + Ok(Expr::ScalarUDF { fun: fun_arc, args }) + } else { + Err(proto_error(format!("can not found udf plugins!"))) + } + } // argo engine add end ExprType::Alias(alias) => Ok(Expr::Alias( Box::new(parse_required_expr(&alias.expr)?), alias.alias.clone(), @@ -1169,10 +1334,14 @@ impl TryInto for &protobuf::Field { use crate::serde::protobuf::ColumnStats; use datafusion::physical_plan::{aggregates, windows}; +use datafusion::plugin::plugin_manager::global_plugin_manager; +use datafusion::plugin::udf::UDFPluginManager; +use datafusion::plugin::PluginEnum; use datafusion::prelude::{ array, date_part, date_trunc, length, lower, ltrim, md5, rtrim, sha224, sha256, sha384, sha512, trim, upper, }; +use futures::TryFutureExt; use std::convert::TryFrom; impl TryFrom for protobuf::FileType { diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index 36e9ba69ed5a..dcaf9b7f5eb8 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -41,6 +41,7 @@ use datafusion::logical_plan::{ }; use datafusion::physical_plan::aggregates::AggregateFunction; use datafusion::physical_plan::functions::BuiltinScalarFunction; +use 
datafusion::physical_plan::udaf::AggregateUDF; use datafusion::physical_plan::window_functions::{ BuiltInWindowFunction, WindowFunction, }; @@ -50,6 +51,7 @@ use protobuf::{ arrow_type, logical_expr_node::ExprType, scalar_type, DateUnit, PrimitiveScalarType, ScalarListValue, ScalarType, }; +use std::sync::Arc; use std::{ boxed, convert::{TryFrom, TryInto}, @@ -563,6 +565,54 @@ impl TryFrom<&datafusion::scalar::ScalarValue> for protobuf::ScalarValue { Value::TimeNanosecondValue(*s) }) } + + // argo engine add. + datafusion::scalar::ScalarValue::Decimal128(val, p, s) => { + match *val { + Some(v) => { + let array = v.to_be_bytes(); + let vec_val: Vec = array.to_vec(); + protobuf::ScalarValue { + value: Some(Value::Decimal128Value(protobuf::Decimal128 { + value: vec_val, + p: *p as i64, + s: *s as i64, + })), + } + } + None => { + protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::NullValue(PrimitiveScalarType::Decimal128 as i32)) + } + } + } + } + datafusion::scalar::ScalarValue::Date64(val) => { + create_proto_scalar(val, PrimitiveScalarType::Date64, |s| { + Value::Date64Value(*s) + }) + } + datafusion::scalar::ScalarValue::TimestampSecond(val, _) => { + create_proto_scalar(val, PrimitiveScalarType::TimeSecond, |s| { + Value::TimeSecondValue(*s) + }) + } + datafusion::scalar::ScalarValue::TimestampMillisecond(val, _) => { + create_proto_scalar(val, PrimitiveScalarType::TimeMillisecond, |s| { + Value::TimeMillisecondValue(*s) + }) + } + datafusion::scalar::ScalarValue::IntervalYearMonth(val) => { + create_proto_scalar(val, PrimitiveScalarType::IntervalYearmonth, |s| { + Value::IntervalYearmonthValue(*s) + }) + } + datafusion::scalar::ScalarValue::IntervalDayTime(val) => { + create_proto_scalar(val, PrimitiveScalarType::IntervalDaytime, |s| { + Value::IntervalDaytimeValue(*s) + }) + } + // argo engine add end. _ => { return Err(proto_error(format!( "Error converting to Datatype to scalar type, {:?} is invalid as a datafusion scalar.", @@ -1078,8 +1128,40 @@ impl TryInto for &Expr { ), }) } - Expr::ScalarUDF { .. } => unimplemented!(), - Expr::AggregateUDF { .. 
} => unimplemented!(), + // argo engine add start + Expr::ScalarUDF { ref fun, ref args } => { + let args: Vec = args + .iter() + .map(|e| e.try_into()) + .collect::, BallistaError>>()?; + Ok(protobuf::LogicalExprNode { + expr_type: Some( + protobuf::logical_expr_node::ExprType::ScalarUdfProtoExpr( + protobuf::ScalarUdfProtoExprNode { + fun_name: fun.name.clone(), + args, + }, + ), + ), + }) + } + Expr::AggregateUDF { ref fun, ref args } => { + let args: Vec = args + .iter() + .map(|e| e.try_into()) + .collect::, BallistaError>>()?; + Ok(protobuf::LogicalExprNode { + expr_type: Some( + protobuf::logical_expr_node::ExprType::AggregateUdfExpr( + protobuf::AggregateUdfExprNode { + fun_name: fun.name.clone(), + args, + }, + ), + ), + }) + } + // argo engine add end Expr::Not(expr) => { let expr = Box::new(protobuf::Not { expr: Some(Box::new(expr.as_ref().try_into()?)), diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs index 62246a0232df..a8a0deb25054 100644 --- a/ballista/rust/core/src/serde/mod.rs +++ b/ballista/rust/core/src/serde/mod.rs @@ -20,7 +20,7 @@ use std::{convert::TryInto, io::Cursor}; -use datafusion::arrow::datatypes::UnionMode; +use datafusion::arrow::datatypes::{IntervalUnit, UnionMode}; use datafusion::logical_plan::{JoinConstraint, JoinType, Operator}; use datafusion::physical_plan::aggregates::AggregateFunction; use datafusion::physical_plan::window_functions::BuiltInWindowFunction; @@ -314,6 +314,21 @@ impl Into for protobuf::PrimitiveScalarT DataType::Time64(TimeUnit::Nanosecond) } protobuf::PrimitiveScalarType::Null => DataType::Null, + // argo engine add. + protobuf::PrimitiveScalarType::Decimal128 => DataType::Decimal(0, 0), + protobuf::PrimitiveScalarType::Date64 => DataType::Date64, + protobuf::PrimitiveScalarType::TimeSecond => { + DataType::Timestamp(TimeUnit::Second, None) + } + protobuf::PrimitiveScalarType::TimeMillisecond => { + DataType::Timestamp(TimeUnit::Millisecond, None) + } + protobuf::PrimitiveScalarType::IntervalYearmonth => { + DataType::Interval(IntervalUnit::YearMonth) + } + protobuf::PrimitiveScalarType::IntervalDaytime => { + DataType::Interval(IntervalUnit::DayTime) + } // argo engine add end. 
} } } @@ -375,3 +390,9 @@ fn str_to_byte(s: &str) -> Result { } Ok(s.as_bytes()[0]) } + +fn vec_to_array(v: Vec) -> [T; N] { + v.try_into().unwrap_or_else(|v: Vec| { + panic!("Expected a Vec of length {} but it was {}", N, v.len()) + }) +} diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index cad27b315645..72a0455ea396 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -31,6 +31,7 @@ use crate::serde::scheduler::PartitionLocation; use crate::serde::{from_proto_binary_op, proto_error, protobuf, str_to_byte}; use crate::{convert_box_required, convert_required, into_required}; use chrono::{TimeZone, Utc}; +use datafusion::arrow::compute::eq_dyn; use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion::catalog::catalog::{ CatalogList, CatalogProvider, MemoryCatalogList, MemoryCatalogProvider, @@ -55,6 +56,8 @@ use datafusion::physical_plan::hash_join::PartitionMode; use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet; use datafusion::physical_plan::planner::DefaultPhysicalPlanner; use datafusion::physical_plan::sorts::sort::{SortExec, SortOptions}; +use datafusion::physical_plan::udaf::create_aggregate_expr as create_aggregate_udf_expr; +use datafusion::physical_plan::udf::{create_physical_expr, ScalarUDFExpr}; use datafusion::physical_plan::window_functions::{ BuiltInWindowFunction, WindowFunction, }; @@ -79,6 +82,9 @@ use datafusion::physical_plan::{ use datafusion::physical_plan::{ AggregateExpr, ColumnStatistics, ExecutionPlan, PhysicalExpr, Statistics, WindowExpr, }; +use datafusion::plugin::plugin_manager::global_plugin_manager; +use datafusion::plugin::udf::UDFPluginManager; +use datafusion::plugin::PluginEnum; use datafusion::prelude::CsvReadOptions; use log::debug; use protobuf::physical_expr_node::ExprType; @@ -311,6 +317,37 @@ impl TryInto> for &protobuf::PhysicalPlanNode { name.to_string(), )?) } + ExprType::AggregateUdfExpr(agg_node) => { + let name = agg_node.fun_name.as_str(); + let udaf_fun_name = &name[0..name.find('(').unwrap()]; + let gpm = global_plugin_manager("").lock().unwrap(); + let plugin_registrar = gpm.plugin_managers.get(&PluginEnum::UDF).unwrap(); + if let Some(udf_plugin_manager) = plugin_registrar.as_any().downcast_ref::() + { + let fun = udf_plugin_manager.aggregate_udfs.get(udaf_fun_name).ok_or_else(|| { + proto_error(format!( + "can not get udaf:{} from plugins!", + udaf_fun_name.to_string() + )) + })?; + let aggregate_udf = &*fun.clone(); + let args: Vec> = agg_node.expr + .iter() + .map(|e| e.try_into()) + .collect::, BallistaError>>()?; + + Ok(create_aggregate_udf_expr( + aggregate_udf, + &args, + &physical_schema, + name.to_string(), + )?) + } else { + Err(proto_error(format!( + "can not found udf plugin!" + ))) + } + } // argo engine add end. _ => Err(BallistaError::General( "Invalid aggregate expression for HashAggregateExec" .to_string(), @@ -545,6 +582,46 @@ impl TryFrom<&protobuf::PhysicalExprNode> for Arc { .to_owned(), )); } + // argo engine add. 
+ ExprType::ScalarUdfProtoExpr(e) => { + let gpm = global_plugin_manager("").lock().unwrap(); + let plugin_registrar = gpm.plugin_managers.get(&PluginEnum::UDF).unwrap(); + if let Some(udf_plugin_manager) = + plugin_registrar.as_any().downcast_ref::() + { + let fun = udf_plugin_manager + .scalar_udfs + .get(&e.fun_name) + .ok_or_else(|| { + proto_error(format!( + "can not get udf:{} from plugin!", + &e.fun_name.to_owned() + )) + })?; + + let scalar_udf = &*fun.clone(); + let args = e + .expr + .iter() + .map(|x| x.try_into()) + .collect::, _>>()?; + + Arc::new(ScalarUDFExpr::new( + e.fun_name.as_str(), + scalar_udf.clone(), + args, + &convert_required!(e.return_type)?, + )) + } else { + return Err(proto_error(format!("can not found plugin!"))); + } + } + ExprType::AggregateUdfExpr(_) => { + return Err(BallistaError::General( + "Cannot convert aggregate udf expr node to physical expression" + .to_owned(), + )); + } // argo engine add end. ExprType::WindowExpr(_) => { return Err(BallistaError::General( "Cannot convert window expr node to physical expression".to_owned(), diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs index 930f0757e202..357fb9f1910c 100644 --- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs @@ -74,6 +74,8 @@ use crate::{ use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::physical_plan::functions::{BuiltinScalarFunction, ScalarFunctionExpr}; use datafusion::physical_plan::repartition::RepartitionExec; +use datafusion::physical_plan::udaf::AggregateFunctionExpr; +use datafusion::physical_plan::udf::{ScalarUDF, ScalarUDFExpr}; impl TryInto for Arc { type Error = BallistaError; @@ -412,35 +414,54 @@ impl TryInto for Arc { type Error = BallistaError; fn try_into(self) -> Result { - let aggr_function = if self.as_any().downcast_ref::().is_some() { - Ok(protobuf::AggregateFunction::Avg.into()) - } else if self.as_any().downcast_ref::().is_some() { - Ok(protobuf::AggregateFunction::Sum.into()) - } else if self.as_any().downcast_ref::().is_some() { - Ok(protobuf::AggregateFunction::Count.into()) - } else if self.as_any().downcast_ref::().is_some() { - Ok(protobuf::AggregateFunction::Min.into()) - } else if self.as_any().downcast_ref::().is_some() { - Ok(protobuf::AggregateFunction::Max.into()) - } else { - Err(BallistaError::NotImplemented(format!( - "Aggregate function not supported: {:?}", - self - ))) - }?; + // argo engine add. 
+ // aggregate udf let expressions: Vec = self .expressions() .iter() .map(|e| e.clone().try_into()) .collect::, BallistaError>>()?; - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr( - Box::new(protobuf::PhysicalAggregateExprNode { - aggr_function, - expr: Some(Box::new(expressions[0].clone())), - }), - )), - }) + if self + .as_any() + .downcast_ref::() + .is_some() + { + Ok(protobuf::PhysicalExprNode { + expr_type: Some( + protobuf::physical_expr_node::ExprType::AggregateUdfExpr( + protobuf::PhysicalAggregateUdfExprNode { + fun_name: self.name().to_string(), + expr: expressions, + }, + ), + ), + }) + } else { + let aggr_function = if self.as_any().downcast_ref::().is_some() { + Ok(protobuf::AggregateFunction::Avg.into()) + } else if self.as_any().downcast_ref::().is_some() { + Ok(protobuf::AggregateFunction::Sum.into()) + } else if self.as_any().downcast_ref::().is_some() { + Ok(protobuf::AggregateFunction::Count.into()) + } else if self.as_any().downcast_ref::().is_some() { + Ok(protobuf::AggregateFunction::Min.into()) + } else if self.as_any().downcast_ref::().is_some() { + Ok(protobuf::AggregateFunction::Max.into()) + } else { + Err(BallistaError::NotImplemented(format!( + "Aggregate function not supported: {:?}", + self + ))) + }?; + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr( + Box::new(protobuf::PhysicalAggregateExprNode { + aggr_function, + expr: Some(Box::new(expressions[0].clone())), + }), + )), + }) + } } } @@ -597,6 +618,25 @@ impl TryFrom> for protobuf::PhysicalExprNode { }, )), }) + } else if let Some(expr) = expr.downcast_ref::() { + let args: Vec = expr + .args() + .iter() + .map(|e| e.to_owned().try_into()) + .collect::, _>>()?; + let data_type = expr.return_type().clone(); + let return_type = (&data_type).into(); + Ok(protobuf::PhysicalExprNode { + expr_type: Some( + protobuf::physical_expr_node::ExprType::ScalarUdfProtoExpr( + protobuf::PhysicalScalarUdfProtoExprNode { + fun_name: expr.name().to_string(), + expr: args, + return_type: Some(return_type), + }, + ), + ), + }) } else { Err(BallistaError::General(format!( "physical_plan::to_proto() unsupported expression {:?}", diff --git a/ballista/rust/executor/Cargo.toml b/ballista/rust/executor/Cargo.toml index c01bb20681db..d0b4186ec00c 100644 --- a/ballista/rust/executor/Cargo.toml +++ b/ballista/rust/executor/Cargo.toml @@ -33,9 +33,9 @@ arrow = { version = "7.0.0" } arrow-flight = { version = "7.0.0" } anyhow = "1" async-trait = "0.1.36" -ballista-core = { path = "../core", version = "0.6.0" } +ballista-core = { path = "../core"} configure_me = "0.4.0" -datafusion = { path = "../../../datafusion", version = "6.0.0" } +datafusion = { path = "../../../datafusion" } env_logger = "0.9" futures = "0.3" log = "0.4" diff --git a/ballista/rust/scheduler/Cargo.toml b/ballista/rust/scheduler/Cargo.toml index 0bacccf031d8..3799befffa03 100644 --- a/ballista/rust/scheduler/Cargo.toml +++ b/ballista/rust/scheduler/Cargo.toml @@ -32,10 +32,10 @@ sled = ["sled_package", "tokio-stream"] [dependencies] anyhow = "1" -ballista-core = { path = "../core", version = "0.6.0" } +ballista-core = { path = "../core"} clap = "2" configure_me = "0.4.0" -datafusion = { path = "../../../datafusion", version = "6.0.0" } +datafusion = { path = "../../../datafusion" } env_logger = "0.9" etcd-client = { version = "0.7", optional = true } futures = "0.3" @@ -55,7 +55,7 @@ tower = { version = "0.4" } warp = "0.3" [dev-dependencies] 
-ballista-core = { path = "../core", version = "0.6.0" } +ballista-core = { path = "../core"} uuid = { version = "0.8", features = ["v4"] } [build-dependencies] diff --git a/ballista/rust/scheduler/src/state/mod.rs b/ballista/rust/scheduler/src/state/mod.rs index ef6de8312702..535ced9d85df 100644 --- a/ballista/rust/scheduler/src/state/mod.rs +++ b/ballista/rust/scheduler/src/state/mod.rs @@ -320,34 +320,37 @@ impl SchedulerState { .await?; if task_is_dead { continue 'tasks; - } else if let Some(task_status::Status::Completed( - CompletedTask { + } + + match &referenced_task.status { + Some(task_status::Status::Completed(CompletedTask { executor_id, partitions, - }, - )) = &referenced_task.status - { - debug!("Task for unresolved shuffle input partition {} completed and produced these shuffle partitions:\n\t{}", - shuffle_input_partition_id, - partitions.iter().map(|p| format!("{}={}", p.partition_id, &p.path)).collect::>().join("\n\t") - ); - let stage_shuffle_partition_locations = partition_locations - .entry(unresolved_shuffle.stage_id) - .or_insert_with(HashMap::new); - let executor_meta = executors - .iter() - .find(|exec| exec.id == *executor_id) - .unwrap() - .clone(); - - for shuffle_write_partition in partitions { - let temp = stage_shuffle_partition_locations - .entry(shuffle_write_partition.partition_id as usize) - .or_insert_with(Vec::new); - let executor_meta = executor_meta.clone(); - let partition_location = - ballista_core::serde::scheduler::PartitionLocation { - partition_id: + })) => { + debug!("Task for unresolved shuffle input partition {} completed and produced these shuffle partitions:\n\t{}", + shuffle_input_partition_id, + partitions.iter().map(|p| format!("{}={}", p.partition_id, &p.path)).collect::>().join("\n\t") + ); + let stage_shuffle_partition_locations = + partition_locations + .entry(unresolved_shuffle.stage_id) + .or_insert_with(HashMap::new); + let executor_meta = executors + .iter() + .find(|exec| exec.id == *executor_id) + .unwrap() + .clone(); + + for shuffle_write_partition in partitions { + let temp = stage_shuffle_partition_locations + .entry( + shuffle_write_partition.partition_id as usize, + ) + .or_insert_with(Vec::new); + let executor_meta = executor_meta.clone(); + let partition_location = + ballista_core::serde::scheduler::PartitionLocation { + partition_id: ballista_core::serde::scheduler::PartitionId { job_id: partition.job_id.clone(), stage_id: unresolved_shuffle.stage_id, @@ -355,29 +358,44 @@ impl SchedulerState { .partition_id as usize, }, - executor_meta, - partition_stats: PartitionStats::new( - Some(shuffle_write_partition.num_rows), - Some(shuffle_write_partition.num_batches), - Some(shuffle_write_partition.num_bytes), - ), - path: shuffle_write_partition.path.clone(), - }; + executor_meta, + partition_stats: PartitionStats::new( + Some(shuffle_write_partition.num_rows), + Some(shuffle_write_partition.num_batches), + Some(shuffle_write_partition.num_bytes), + ), + path: shuffle_write_partition.path.clone(), + }; + + debug!( + "Scheduler storing stage {} output partition {} path: {}", + unresolved_shuffle.stage_id, + partition_location.partition_id.partition_id, + partition_location.path + ); + temp.push(partition_location); + } + } + Some(task_status::Status::Failed(FailedTask { error })) => { + // A task should fail when its referenced_task fails + let mut status = status.clone(); + let err_msg = format!("{}", error); + status.status = + Some(task_status::Status::Failed(FailedTask { + error: err_msg, + })); + 
self.save_task_status(&status).await?; + continue 'tasks; + } + _ => { debug!( - "Scheduler storing stage {} output partition {} path: {}", + "Stage {} input partition {} has not completed yet", unresolved_shuffle.stage_id, - partition_location.partition_id.partition_id, - partition_location.path - ); - temp.push(partition_location); + shuffle_input_partition_id, + ); + continue 'tasks; } - } else { - debug!( - "Stage {} input partition {} has not completed yet", - unresolved_shuffle.stage_id, shuffle_input_partition_id, - ); - continue 'tasks; - } + }; } } diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml index d20de3106bd3..d2e4b3143ec0 100644 --- a/benchmarks/Cargo.toml +++ b/benchmarks/Cargo.toml @@ -32,7 +32,7 @@ simd = ["datafusion/simd"] snmalloc = ["snmalloc-rs"] [dependencies] -datafusion = { path = "../datafusion" } +datafusion = { path = "../datafusion" } ballista = { path = "../ballista/rust/client" } structopt = { version = "0.3", default-features = false } tokio = { version = "^1.0", features = ["macros", "rt", "rt-multi-thread"] } diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index d5347d8e0009..97215e254fa0 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -30,6 +30,6 @@ rust-version = "1.57" clap = "2.33" rustyline = "9.0" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] } -datafusion = { path = "../datafusion", version = "6.0.0" } +datafusion = { path = "../datafusion" } arrow = { version = "7.0.0" } ballista = { path = "../ballista/rust/client", version = "0.6.0" } diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml index bc37c7a0de20..344ad948c0cc 100644 --- a/datafusion/Cargo.toml +++ b/datafusion/Cargo.toml @@ -78,11 +78,18 @@ avro-rs = { version = "0.13", features = ["snappy"], optional = true } num-traits = { version = "0.2", optional = true } pyo3 = { version = "0.14", optional = true } tempfile = "3" +libloading = "0.7.3" +rustc_version = "0.4.0" +walkdir = "2.3.2" +once_cell = "1.9.0" [dev-dependencies] criterion = "0.3" doc-comment = "0.3" +[build-dependencies] +rustc_version = "0.4.0" + [[bench]] name = "aggregate_query_sql" harness = false diff --git a/datafusion/build.rs b/datafusion/build.rs new file mode 100644 index 000000000000..a38022ffdaa7 --- /dev/null +++ b/datafusion/build.rs @@ -0,0 +1,4 @@ +fn main() { + let version = rustc_version::version().unwrap(); + println!("cargo:rustc-env=RUSTC_VERSION={}", version); +} diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 89ccd7b2b938..6e542c7c6804 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -83,6 +83,9 @@ use crate::physical_plan::planner::DefaultPhysicalPlanner; use crate::physical_plan::udf::ScalarUDF; use crate::physical_plan::ExecutionPlan; use crate::physical_plan::PhysicalPlanner; +use crate::plugin::plugin_manager::global_plugin_manager; +use crate::plugin::udf::UDFPluginManager; +use crate::plugin::PluginEnum; use crate::sql::{ parser::{DFParser, FileType}, planner::{ContextProvider, SqlToRel}, @@ -182,18 +185,40 @@ impl ExecutionContext { let runtime_env = Arc::new(RuntimeEnv::new(config.runtime_config.clone()).unwrap()); - Self { + let mut context = Self { state: Arc::new(Mutex::new(ExecutionContextState { catalog_list, scalar_functions: HashMap::new(), var_provider: HashMap::new(), aggregate_functions: HashMap::new(), - config, + config: config.clone(), execution_props: ExecutionProps::new(), object_store_registry: 
Arc::new(ObjectStoreRegistry::new()), runtime_env, })), + }; + + let gpm = global_plugin_manager(config.plugin_dir.as_str()); + + // register udf + let gpm_guard = gpm.lock().unwrap(); + let plugin_registrar = gpm_guard.plugin_managers.get(&PluginEnum::UDF).unwrap(); + if let Some(udf_plugin_manager) = + plugin_registrar.as_any().downcast_ref::() + { + udf_plugin_manager + .scalar_udfs + .iter() + .for_each(|(_, scalar_udf)| context.register_udf((**scalar_udf).clone())); + + udf_plugin_manager + .aggregate_udfs + .iter() + .for_each(|(_, aggregate_udf)| { + context.register_udaf((**aggregate_udf).clone()) + }); } + context } /// Creates a dataframe that will execute a SQL query. @@ -902,6 +927,8 @@ pub struct ExecutionConfig { parquet_pruning: bool, /// Runtime configurations such as memory threshold and local disk for spill pub runtime_config: RuntimeConfig, + /// plugin dir + pub plugin_dir: String, } impl Default for ExecutionConfig { @@ -937,6 +964,7 @@ impl Default for ExecutionConfig { repartition_windows: true, parquet_pruning: true, runtime_config: RuntimeConfig::default(), + plugin_dir: "".to_owned(), } } } diff --git a/datafusion/src/lib.rs b/datafusion/src/lib.rs index fd574d7d76ae..1d3c5250850b 100644 --- a/datafusion/src/lib.rs +++ b/datafusion/src/lib.rs @@ -213,6 +213,8 @@ pub mod logical_plan; pub mod optimizer; pub mod physical_optimizer; pub mod physical_plan; +/// plugin mod +pub mod plugin; pub mod prelude; pub mod scalar; pub mod sql; diff --git a/datafusion/src/physical_plan/udf.rs b/datafusion/src/physical_plan/udf.rs index af0765877c1b..55c37dea3374 100644 --- a/datafusion/src/physical_plan/udf.rs +++ b/datafusion/src/physical_plan/udf.rs @@ -18,19 +18,20 @@ //! UDF support use fmt::{Debug, Formatter}; +use std::any::Any; use std::fmt; -use arrow::datatypes::Schema; +use arrow::datatypes::{DataType, Schema}; use crate::error::Result; use crate::{logical_plan::Expr, physical_plan::PhysicalExpr}; use super::{ - functions::{ - ReturnTypeFunction, ScalarFunctionExpr, ScalarFunctionImplementation, Signature, - }, + functions::{ReturnTypeFunction, ScalarFunctionImplementation, Signature}, type_coercion::coerce, }; +use crate::physical_plan::ColumnarValue; +use arrow::record_batch::RecordBatch; use std::sync::Arc; /// Logical representation of a UDF. @@ -121,10 +122,102 @@ pub fn create_physical_expr( .map(|e| e.data_type(input_schema)) .collect::>>()?; - Ok(Arc::new(ScalarFunctionExpr::new( - &fun.name, - fun.fun.clone(), - coerced_phy_exprs, - (fun.return_type)(&coerced_exprs_types)?.as_ref(), - ))) + Ok(Arc::new(ScalarUDFExpr { + fun: fun.clone(), + name: fun.name.clone(), + args: coerced_phy_exprs.clone(), + return_type: (fun.return_type)(&coerced_exprs_types)?.as_ref().clone(), + })) } + +/// Physical expression of a UDF. 
+/// argo engine add
+#[derive(Debug)]
+pub struct ScalarUDFExpr {
+    fun: ScalarUDF,
+    name: String,
+    args: Vec<Arc<dyn PhysicalExpr>>,
+    return_type: DataType,
+}
+
+impl ScalarUDFExpr {
+    /// create a ScalarUDFExpr
+    pub fn new(
+        name: &str,
+        fun: ScalarUDF,
+        args: Vec<Arc<dyn PhysicalExpr>>,
+        return_type: &DataType,
+    ) -> Self {
+        Self {
+            fun,
+            name: name.to_string(),
+            args,
+            return_type: return_type.clone(),
+        }
+    }
+
+    /// return the fun
+    pub fn fun(&self) -> &ScalarUDF {
+        &self.fun
+    }
+
+    /// return the name
+    pub fn name(&self) -> &str {
+        &self.name
+    }
+
+    /// return the args
+    pub fn args(&self) -> &[Arc<dyn PhysicalExpr>] {
+        &self.args
+    }
+
+    /// Data type produced by this expression
+    pub fn return_type(&self) -> &DataType {
+        &self.return_type
+    }
+}
+
+impl fmt::Display for ScalarUDFExpr {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(
+            f,
+            "{}({})",
+            self.name,
+            self.args
+                .iter()
+                .map(|e| format!("{}", e))
+                .collect::<Vec<String>>()
+                .join(", ")
+        )
+    }
+}
+
+impl PhysicalExpr for ScalarUDFExpr {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
+        Ok(self.return_type.clone())
+    }
+
+    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
+        Ok(true)
+    }
+
+    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
+        // evaluate the arguments; if there are no arguments we would instead pass in a null array
+        // indicating the batch size (as a convention)
+        // TODO: support zero input arguments
+        let inputs = self
+            .args
+            .iter()
+            .map(|e| e.evaluate(batch))
+            .collect::<Result<Vec<_>>>()?;
+
+        // evaluate the function
+        let fun = self.fun.fun.as_ref();
+        (fun)(&inputs)
+    }
+} // argo engine add end.
diff --git a/datafusion/src/plugin/mod.rs b/datafusion/src/plugin/mod.rs
new file mode 100644
index 000000000000..67d6655a2b07
--- /dev/null
+++ b/datafusion/src/plugin/mod.rs
@@ -0,0 +1,120 @@
+use crate::error::Result;
+use crate::plugin::udf::UDFPluginManager;
+use std::any::Any;
+use std::env;
+
+/// plugin manager
+pub mod plugin_manager;
+/// udf plugin
+pub mod udf;
+
+/// CARGO_PKG_VERSION
+pub static CORE_VERSION: &str = env!("CARGO_PKG_VERSION");
+/// RUSTC_VERSION
+pub static RUSTC_VERSION: &str = env!("RUSTC_VERSION");
+
+/// Top-level plugin trait
+pub trait Plugin {
+    /// Returns the plugin as [`Any`](std::any::Any) so that it can be
+    /// downcast to a specific implementation.
+    fn as_any(&self) -> &dyn Any;
+}
+
+/// The enum of plugin types
+#[derive(PartialEq, std::cmp::Eq, std::hash::Hash, Copy, Clone)]
+pub enum PluginEnum {
+    /// UDF/UDAF plugin
+    UDF,
+}
+
+impl PluginEnum {
+    /// Create the struct which implements the PluginRegistrar trait for this plugin type
+    pub fn init_plugin_manager(&self) -> Box<dyn PluginRegistrar> {
+        match self {
+            PluginEnum::UDF => Box::new(UDFPluginManager::default()),
+        }
+    }
+}
+
+/// Every plugin needs a PluginDeclaration
+#[derive(Copy, Clone)]
+pub struct PluginDeclaration {
+    /// rustc version of the plugin. The plugin's rustc_version must match the plugin manager's.
+    pub rustc_version: &'static str,
+
+    /// core version of the plugin. The plugin's core_version must match the plugin manager's.
+    pub core_version: &'static str,
+
+    /// One of PluginEnum
+    pub plugin_type: unsafe extern "C" fn() -> PluginEnum,
+
+    /// `register` is the function that registers the plugin with the PluginRegistrar. It is called when the plugin is loaded.
+    pub register: unsafe extern "C" fn(&mut Box<dyn PluginRegistrar>),
+}
+
+/// Plugin Registrar. Every plugin manager needs to implement this trait
+pub trait PluginRegistrar: Send + Sync + 'static {
+    /// The implementer of a plug-in calls this interface to report its own information to the plug-in manager
+    fn register_plugin(&mut self, plugin: Box<dyn Plugin>) -> Result<()>;
+
+    /// Returns the plugin registrar as [`Any`](std::any::Any) so that it can be
+    /// downcast to a specific implementation.
+    fn as_any(&self) -> &dyn Any;
+}
+
+/// Declare a plugin's PluginDeclaration.
+///
+/// # Notes
+///
+/// This works by automatically generating `extern "C"` functions with
+/// pre-defined signatures and symbol names, and then generating the PluginDeclaration.
+/// Therefore you will only be able to declare one plugin per library.
+#[macro_export]
+macro_rules! declare_plugin {
+    ($plugin_type:expr, $curr_plugin_type:ty, $constructor:path) => {
+        #[no_mangle]
+        pub extern "C" fn register_plugin(
+            registrar: &mut Box<dyn $crate::plugin::PluginRegistrar>,
+        ) {
+            // make sure the constructor is the correct type.
+            let constructor: fn() -> $curr_plugin_type = $constructor;
+            let object = constructor();
+            registrar.register_plugin(Box::new(object)).unwrap();
+        }
+
+        #[no_mangle]
+        pub extern "C" fn get_plugin_type() -> $crate::plugin::PluginEnum {
+            $plugin_type
+        }
+
+        #[no_mangle]
+        pub static plugin_declaration: $crate::plugin::PluginDeclaration =
+            $crate::plugin::PluginDeclaration {
+                rustc_version: $crate::plugin::RUSTC_VERSION,
+                core_version: $crate::plugin::CORE_VERSION,
+                plugin_type: get_plugin_type,
+                register: register_plugin,
+            };
+    };
+}
+
+/// get the plugin dir
+pub fn plugin_dir() -> String {
+    let current_exe_dir = match env::current_exe() {
+        Ok(exe_path) => exe_path.display().to_string(),
+        Err(_e) => "".to_string(),
+    };
+
+    // If current_exe_dir contains `/deps/`, the plugin dir is its parent dir
+    // eg: /Users/xxx/workspace/rust/rust_plugin_sty/target/debug/deps/plugins_app-067452b3ff2af70e
+    // the plugin dir is /Users/xxx/workspace/rust/rust_plugin_sty/target/debug
+    // else eg: /Users/xxx/workspace/rust/rust_plugin_sty/target/debug/plugins_app
+    // the plugin dir is /Users/xxx/workspace/rust/rust_plugin_sty/target/debug/
+    if current_exe_dir.contains("/deps/") {
+        let i = current_exe_dir.find("/deps/").unwrap();
+        String::from(&current_exe_dir.as_str()[..i])
+    } else {
+        let i = current_exe_dir.rfind('/').unwrap();
+        String::from(&current_exe_dir.as_str()[..i])
+    }
+}
diff --git a/datafusion/src/plugin/plugin_manager.rs b/datafusion/src/plugin/plugin_manager.rs
new file mode 100644
index 000000000000..a8a19e4ac8d9
--- /dev/null
+++ b/datafusion/src/plugin/plugin_manager.rs
@@ -0,0 +1,131 @@
+use crate::error::{DataFusionError, Result};
+use crate::plugin::{PluginDeclaration, CORE_VERSION, RUSTC_VERSION};
+use crate::plugin::{PluginEnum, PluginRegistrar};
+use libloading::Library;
+use log::info;
+use std::collections::HashMap;
+use std::io;
+use std::sync::{Arc, Mutex};
+use walkdir::{DirEntry, WalkDir};
+
+use once_cell::sync::OnceCell;
+
+/// To prevent a library from being loaded multiple times, we use once_cell to define a static `Arc<Mutex<GlobalPluginManager>>`.
+/// Because datafusion is a library, not a service, users may not need to load all plug-ins in the process.
+/// So `global_plugin_manager` returns `Arc<Mutex<GlobalPluginManager>>`. In this way, users can load the
+/// required libraries through the `load` method of GlobalPluginManager when needed.
+pub fn global_plugin_manager(
+    plugin_path: &str,
+) -> &'static Arc<Mutex<GlobalPluginManager>> {
+    static INSTANCE: OnceCell<Arc<Mutex<GlobalPluginManager>>> = OnceCell::new();
+    INSTANCE.get_or_init(move || unsafe {
+        let mut gpm = GlobalPluginManager::default();
+        gpm.load(plugin_path).unwrap();
+        Arc::new(Mutex::new(gpm))
+    })
+}
+
+#[derive(Default)]
+/// manages every plugin type's plugin manager
+pub struct GlobalPluginManager {
+    /// every plugin type needs a plugin registrar
+    pub plugin_managers: HashMap<PluginEnum, Box<dyn PluginRegistrar>>,
+
+    /// loaded plugin files
+    pub plugin_files: Vec<String>,
+}
+
+impl GlobalPluginManager {
+    /// # Safety
+    /// Find plugin files in `plugin_path` and load them.
+    unsafe fn load(&mut self, plugin_path: &str) -> Result<()> {
+        // find library files in the plugin path
+        info!("load plugin from dir:{}", plugin_path);
+        println!("load plugin from dir:{}", plugin_path);
+
+        let plugin_files = self.get_all_plugin_files(plugin_path)?;
+
+        for plugin_file in plugin_files {
+            let library = Library::new(plugin_file.path()).map_err(|e| {
+                DataFusionError::IoError(io::Error::new(
+                    io::ErrorKind::Other,
+                    format!("load library error: {}", e),
+                ))
+            })?;
+
+            let library = Arc::new(library);
+
+            // get a pointer to the plugin_declaration symbol.
+            let dec = library
+                .get::<*mut PluginDeclaration>(b"plugin_declaration\0")
+                .map_err(|e| {
+                    DataFusionError::IoError(io::Error::new(
+                        io::ErrorKind::Other,
+                        format!("plugin_declaration not found in the library: {}", e),
+                    ))
+                })?
+                .read();
+
+            // version checks to prevent accidental ABI incompatibilities
+            if dec.rustc_version != RUSTC_VERSION || dec.core_version != CORE_VERSION {
+                return Err(DataFusionError::IoError(io::Error::new(
+                    io::ErrorKind::Other,
+                    "Version mismatch",
+                )));
+            }
+
+            let plugin_enum = (dec.plugin_type)();
+            let curr_plugin_manager = match self.plugin_managers.get_mut(&plugin_enum) {
+                None => {
+                    let plugin_manager = plugin_enum.init_plugin_manager();
+                    self.plugin_managers.insert(plugin_enum, plugin_manager);
+                    self.plugin_managers.get_mut(&plugin_enum).unwrap()
+                }
+                Some(manager) => manager,
+            };
+
+            (dec.register)(curr_plugin_manager);
+            self.plugin_files
+                .push(plugin_file.path().to_str().unwrap().to_string());
+        }
+
+        Ok(())
+    }
+
+    /// get all plugin files in the dir
+    fn get_all_plugin_files(&self, plugin_path: &str) -> io::Result<Vec<DirEntry>> {
+        let mut plugin_files = Vec::new();
+        for entry in WalkDir::new(plugin_path).into_iter().filter_map(|e| {
+            let item = e.unwrap();
+            // every file is only loaded once
+            if self
+                .plugin_files
+                .contains(&item.path().to_str().unwrap().to_string())
+            {
+                return None;
+            }
+
+            let file_type = item.file_type();
+            if !file_type.is_file() {
+                return None;
+            }
+
+            if let Some(path) = item.path().extension() {
+                if let Some(suffix) = path.to_str() {
+                    if suffix == "dylib" || suffix == "so" || suffix == "dll" {
+                        info!(
+                            "load plugin from library file:{}",
+                            item.path().to_str().unwrap()
+                        );
+                        println!(
+                            "load plugin from library file:{}",
+                            item.path().to_str().unwrap()
+                        );
+                        return Some(item);
+                    }
+                }
+            }
+
+            None
+        }) {
+            plugin_files.push(entry);
+        }
+        Ok(plugin_files)
+    }
+}
diff --git a/datafusion/src/plugin/udf.rs b/datafusion/src/plugin/udf.rs
new file mode 100644
index 000000000000..ffbb928fbd0f
--- /dev/null
+++ b/datafusion/src/plugin/udf.rs
@@ -0,0 +1,88 @@
+use crate::error::{DataFusionError, Result};
+use crate::physical_plan::udaf::AggregateUDF;
+use crate::physical_plan::udf::ScalarUDF;
+use crate::plugin::{Plugin, PluginRegistrar};
+use libloading::Library;
+use std::any::Any;
+use std::collections::HashMap;
+use std::io;
+use std::sync::Arc;
+
+/// Defines the UDF plugin. A UDF provider needs to implement this trait.
+pub trait UDFPlugin: Plugin {
+    /// get a ScalarUDF by name
+    fn get_scalar_udf_by_name(&self, fun_name: &str) -> Result<ScalarUDF>;
+
+    /// return all udf names in the plugin
+    fn udf_names(&self) -> Result<Vec<String>>;
+
+    /// get an aggregate udf by name
+    fn get_aggregate_udf_by_name(&self, fun_name: &str) -> Result<AggregateUDF>;
+
+    /// return all udaf names
+    fn udaf_names(&self) -> Result<Vec<String>>;
+}
+
+/// UDFPluginManager
+#[derive(Default)]
+pub struct UDFPluginManager {
+    /// scalar udfs
+    pub scalar_udfs: HashMap<String, Arc<ScalarUDF>>,
+
+    /// aggregate udfs
+    pub aggregate_udfs: HashMap<String, Arc<AggregateUDF>>,
+
+    /// All libraries loaded from the plugin dir.
+    pub libraries: Vec<Arc<Library>>,
+}
+
+impl PluginRegistrar for UDFPluginManager {
+    fn register_plugin(&mut self, plugin: Box<dyn Plugin>) -> Result<()> {
+        if let Some(udf_plugin) = plugin.as_any().downcast_ref::<Box<dyn UDFPlugin>>() {
+            udf_plugin
+                .udf_names()
+                .unwrap()
+                .iter()
+                .try_for_each(|udf_name| {
+                    if self.scalar_udfs.contains_key(udf_name) {
+                        Err(DataFusionError::IoError(io::Error::new(
+                            io::ErrorKind::Other,
+                            format!("udf name: {} already exists", udf_name),
+                        )))
+                    } else {
+                        let scalar_udf = udf_plugin.get_scalar_udf_by_name(udf_name)?;
+                        self.scalar_udfs
+                            .insert(udf_name.to_string(), Arc::new(scalar_udf));
+                        Ok(())
+                    }
+                })?;
+
+            udf_plugin
+                .udaf_names()
+                .unwrap()
+                .iter()
+                .try_for_each(|udaf_name| {
+                    if self.aggregate_udfs.contains_key(udaf_name) {
+                        Err(DataFusionError::IoError(io::Error::new(
+                            io::ErrorKind::Other,
+                            format!("udaf name: {} already exists", udaf_name),
+                        )))
+                    } else {
+                        let aggregate_udf =
+                            udf_plugin.get_aggregate_udf_by_name(udaf_name)?;
+                        self.aggregate_udfs
+                            .insert(udaf_name.to_string(), Arc::new(aggregate_udf));
+                        Ok(())
+                    }
+                })?;
+
+            // all udfs/udafs from this plugin registered successfully
+            return Ok(());
+        }
+        Err(DataFusionError::IoError(io::Error::new(
+            io::ErrorKind::Other,
+            "expected plugin type is 'dyn UDFPlugin', but it is not",
+        )))
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+}
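
For reference, below is a minimal sketch of what a UDF plugin crate built against this interface could look like. It is illustrative only and not part of the patch: the `SimpleUdfPlugin` type and the `my_pow` function are hypothetical names, and the crate would presumably be compiled as a `cdylib` whose library file is dropped into the configured plugin directory. It reuses the existing `create_udf` and `make_scalar_function` helpers from datafusion.

// Hypothetical plugin crate (not part of this patch) exposing one scalar UDF, `my_pow`.
use std::any::Any;
use std::sync::Arc;

use datafusion::arrow::array::{ArrayRef, Float64Array};
use datafusion::arrow::datatypes::DataType;
use datafusion::declare_plugin;
use datafusion::error::{DataFusionError, Result};
use datafusion::logical_plan::create_udf;
use datafusion::physical_plan::functions::{make_scalar_function, Volatility};
use datafusion::physical_plan::udaf::AggregateUDF;
use datafusion::physical_plan::udf::ScalarUDF;
use datafusion::plugin::udf::UDFPlugin;
use datafusion::plugin::{Plugin, PluginEnum};

#[derive(Default)]
pub struct SimpleUdfPlugin;

impl Plugin for SimpleUdfPlugin {
    fn as_any(&self) -> &dyn Any {
        self
    }
}

impl UDFPlugin for SimpleUdfPlugin {
    fn get_scalar_udf_by_name(&self, fun_name: &str) -> Result<ScalarUDF> {
        match fun_name {
            "my_pow" => {
                // my_pow(base, exponent) over two Float64 columns
                let pow = make_scalar_function(|args: &[ArrayRef]| {
                    let base = args[0].as_any().downcast_ref::<Float64Array>().unwrap();
                    let exp = args[1].as_any().downcast_ref::<Float64Array>().unwrap();
                    let result: Float64Array = base
                        .iter()
                        .zip(exp.iter())
                        .map(|(b, e)| match (b, e) {
                            (Some(b), Some(e)) => Some(b.powf(e)),
                            _ => None,
                        })
                        .collect();
                    Ok(Arc::new(result) as ArrayRef)
                });
                Ok(create_udf(
                    "my_pow",
                    vec![DataType::Float64, DataType::Float64],
                    Arc::new(DataType::Float64),
                    Volatility::Immutable,
                    pow,
                ))
            }
            other => Err(DataFusionError::Plan(format!("unknown udf: {}", other))),
        }
    }

    fn udf_names(&self) -> Result<Vec<String>> {
        Ok(vec!["my_pow".to_string()])
    }

    fn get_aggregate_udf_by_name(&self, fun_name: &str) -> Result<AggregateUDF> {
        Err(DataFusionError::Plan(format!("unknown udaf: {}", fun_name)))
    }

    fn udaf_names(&self) -> Result<Vec<String>> {
        Ok(vec![])
    }
}

// Export the declaration so GlobalPluginManager can find the `plugin_declaration` symbol,
// check RUSTC_VERSION/CORE_VERSION, and call `register_plugin`.
declare_plugin!(PluginEnum::UDF, SimpleUdfPlugin, SimpleUdfPlugin::default);

On the engine side, setting `ExecutionConfig::plugin_dir` to the directory containing the compiled library is what causes `ExecutionContext::new` to call `global_plugin_manager`, load the library, and register the exposed UDFs and UDAFs, as shown in the `context.rs` hunk above.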