Skip to content

Commit

Permalink
ARROW-9837: [Rust][DataFusion] Added provider for variable
Browse files Browse the repository at this point in the history
Select @@Version, @name;

@@Version is a variable whose value is not known to the query engine itself; it has to be obtained from outside the system.
@@Version is a system variable; @name is a user-defined variable.

Closes apache#8135 from wqc200/master_variable

Authored-by: wqc200 <[email protected]>
Signed-off-by: Andy Grove <[email protected]>
  • Loading branch information
wqc200 authored and GeorgeAp committed Jun 7, 2021
1 parent 0ff4499 commit 784c05f
Show file tree
Hide file tree
Showing 9 changed files with 239 additions and 9 deletions.
54 changes: 53 additions & 1 deletion rust/datafusion/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ use crate::sql::{
parser::{DFParser, FileType},
planner::{SchemaProvider, SqlToRel},
};
use crate::variable::{VarProvider, VarType};

/// ExecutionContext is the main interface for executing queries with DataFusion. The context
/// provides the following functionality:
Expand Down Expand Up @@ -107,6 +108,7 @@ impl ExecutionContext {
state: ExecutionContextState {
datasources: HashMap::new(),
scalar_functions: HashMap::new(),
var_provider: HashMap::new(),
config,
},
};
Expand Down Expand Up @@ -177,6 +179,15 @@ impl ExecutionContext {
Ok(query_planner.statement_to_plan(&statements[0])?)
}

/// Registers `provider` as the variable provider for `variable_type`.
///
/// The provider is stored in the context state's `var_provider` map under
/// `variable_type`; a provider already stored under the same key is replaced.
pub fn register_variable(
    &mut self,
    variable_type: VarType,
    provider: Arc<dyn VarProvider + Send + Sync>,
) {
    let providers = &mut self.state.var_provider;
    providers.insert(variable_type, provider);
}

/// Register a scalar UDF
pub fn register_udf(&mut self, f: ScalarUDF) {
self.state
Expand Down Expand Up @@ -460,6 +471,8 @@ pub struct ExecutionContextState {
pub datasources: HashMap<String, Arc<dyn TableProvider + Send + Sync>>,
/// Scalar functions that are registered with the context
pub scalar_functions: HashMap<String, Arc<ScalarUDF>>,
/// Variable providers that are registered with the context, keyed by variable type
pub var_provider: HashMap<VarType, Arc<dyn VarProvider + Send + Sync>>,
/// Context configuration
pub config: ExecutionConfig,
}
Expand Down Expand Up @@ -504,7 +517,8 @@ mod tests {
use crate::logical_plan::{aggregate_expr, col, create_udf};
use crate::physical_plan::functions::ScalarFunctionImplementation;
use crate::test;
use arrow::array::{ArrayRef, Int32Array};
use crate::variable::VarType;
use arrow::array::{ArrayRef, Int32Array, StringArray};
use arrow::compute::add;
use std::fs::File;
use std::{io::prelude::*, sync::Mutex};
Expand All @@ -530,6 +544,44 @@ mod tests {
Ok(())
}

/// Registers one provider per variable type, then checks that `@@version` and
/// `@name` selected from the `dual` table come back as string columns holding
/// the values produced by the test providers.
#[test]
fn create_variable_expr() -> Result<()> {
    let tmp_dir = TempDir::new("variable_expr")?;
    let mut ctx = create_ctx(&tmp_dir, 4)?;

    // One provider for system variables, one for user defined variables.
    ctx.register_variable(VarType::System, Arc::new(test::variable::SystemVar::new()));
    ctx.register_variable(
        VarType::UserDefined,
        Arc::new(test::variable::UserDefinedVar::new()),
    );

    ctx.register_table("dual", test::create_table_dual());

    let results = collect(&mut ctx, "SELECT @@version, @name FROM dual")?;
    let batch = &results[0];

    assert_eq!(2, batch.num_columns());
    assert_eq!(1, batch.num_rows());
    assert_eq!(field_names(batch), vec!["@@version", "@name"]);

    // Column 0 carries the system variable.
    let version = batch.column(0).as_any().downcast_ref::<StringArray>();
    let version = version.expect("failed to cast version");
    assert_eq!(version.value(0), "system-var-@@version");

    // Column 1 carries the user defined variable.
    let name = batch.column(1).as_any().downcast_ref::<StringArray>();
    let name = name.expect("failed to cast name");
    assert_eq!(name.value(0), "user-defined-var-@name");

    Ok(())
}

#[test]
fn parallel_query_with_filter() -> Result<()> {
let tmp_dir = TempDir::new("parallel_query_with_filter")?;
Expand Down
1 change: 1 addition & 0 deletions rust/datafusion/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ pub mod optimizer;
pub mod physical_plan;
pub mod prelude;
pub mod sql;
pub mod variable;

#[cfg(test)]
pub mod test;
6 changes: 6 additions & 0 deletions rust/datafusion/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ fn create_name(e: &Expr, input_schema: &Schema) -> Result<String> {
match e {
Expr::Alias(_, name) => Ok(name.clone()),
Expr::Column(name) => Ok(name.clone()),
Expr::ScalarVariable(variable_names) => Ok(variable_names.join(".")),
Expr::Literal(value) => Ok(format!("{:?}", value)),
Expr::BinaryExpr { left, op, right } => {
let left = create_name(left, input_schema)?;
Expand Down Expand Up @@ -235,6 +236,8 @@ pub enum Expr {
Alias(Box<Expr>, String),
/// column of a table scan
Column(String),
/// scalar variable like @@version
ScalarVariable(Vec<String>),
/// literal value
Literal(ScalarValue),
/// binary expression e.g. "age > 21"
Expand Down Expand Up @@ -301,6 +304,7 @@ impl Expr {
match self {
Expr::Alias(expr, _) => expr.get_type(schema),
Expr::Column(name) => Ok(schema.field_with_name(name)?.data_type().clone()),
Expr::ScalarVariable(_) => Ok(DataType::Utf8),
Expr::Literal(l) => l.get_datatype(),
Expr::Cast { data_type, .. } => Ok(data_type.clone()),
Expr::ScalarUDF { fun, args } => {
Expand Down Expand Up @@ -388,6 +392,7 @@ impl Expr {
ScalarValue::Null => Ok(true),
_ => Ok(false),
},
Expr::ScalarVariable(_) => Ok(true),
Expr::Cast { expr, .. } => expr.nullable(input_schema),
Expr::ScalarFunction { .. } => Ok(true),
Expr::ScalarUDF { .. } => Ok(true),
Expand Down Expand Up @@ -713,6 +718,7 @@ impl fmt::Debug for Expr {
match self {
Expr::Alias(expr, alias) => write!(f, "{:?} AS {}", expr, alias),
Expr::Column(name) => write!(f, "#{}", name),
Expr::ScalarVariable(var_names) => write!(f, "{}", var_names.join(".")),
Expr::Literal(v) => write!(f, "{:?}", v),
Expr::Cast { expr, data_type } => {
write!(f, "CAST({:?} AS {:?})", expr, data_type)
Expand Down
6 changes: 6 additions & 0 deletions rust/datafusion/src/optimizer/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ pub fn expr_to_column_names(expr: &Expr, accum: &mut HashSet<String>) -> Result<
accum.insert(name.clone());
Ok(())
}
Expr::ScalarVariable(var_names) => {
accum.insert(var_names.join("."));
Ok(())
}
Expr::Literal(_) => {
// not needed
Ok(())
Expand Down Expand Up @@ -206,6 +210,7 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result<Vec<&Expr>> {
Expr::Column(_) => Ok(vec![]),
Expr::Alias(expr, ..) => Ok(vec![expr]),
Expr::Literal(_) => Ok(vec![]),
Expr::ScalarVariable(_) => Ok(vec![]),
Expr::Not(expr) => Ok(vec![expr]),
Expr::Sort { expr, .. } => Ok(vec![expr]),
Expr::Wildcard { .. } => Err(ExecutionError::General(
Expand Down Expand Up @@ -248,6 +253,7 @@ pub fn rewrite_expression(expr: &Expr, expressions: &Vec<Expr>) -> Result<Expr>
Expr::Not(_) => Ok(Expr::Not(Box::new(expressions[0].clone()))),
Expr::Column(_) => Ok(expr.clone()),
Expr::Literal(_) => Ok(expr.clone()),
Expr::ScalarVariable(_) => Ok(expr.clone()),
Expr::Sort {
asc, nulls_first, ..
} => Ok(Expr::Sort {
Expand Down
27 changes: 27 additions & 0 deletions rust/datafusion/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ use crate::physical_plan::sort::SortExec;
use crate::physical_plan::udf;
use crate::physical_plan::{expressions, Distribution};
use crate::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr, PhysicalPlanner};
use crate::variable::VarType;
use arrow::compute::SortOptions;
use arrow::datatypes::Schema;

Expand Down Expand Up @@ -403,6 +404,31 @@ impl DefaultPhysicalPlanner {
Ok(Arc::new(Column::new(name)))
}
Expr::Literal(value) => Ok(Arc::new(Literal::new(value.clone()))),
Expr::ScalarVariable(variable_names) => {
    // A leading "@@" on the first name selects the system variable provider;
    // every other variable goes to the user defined provider.
    // `starts_with` replaces the byte-range slice `[0..2]`, which panicked on
    // names shorter than two bytes or with a multi-byte character at that
    // position. An empty name list is reported as an error instead of
    // panicking on `variable_names[0]`.
    let is_system = match variable_names.first() {
        Some(first) => first.starts_with("@@"),
        None => {
            return Err(ExecutionError::General(
                "Scalar variable has no name".to_string(),
            ))
        }
    };
    let (var_type, missing_provider) = if is_system {
        (VarType::System, "No system variable provider found")
    } else {
        (
            VarType::UserDefined,
            "No user defined variable provider found",
        )
    };
    match ctx_state.var_provider.get(&var_type) {
        Some(provider) => {
            // The variable is resolved once, at planning time, and embedded
            // in the physical plan as a literal.
            let scalar_value = provider.get_value(variable_names.clone())?;
            Ok(Arc::new(Literal::new(scalar_value)))
        }
        None => Err(ExecutionError::General(missing_provider.to_string())),
    }
}
Expr::BinaryExpr { left, op, right } => {
let lhs = self.create_physical_expr(left, input_schema, ctx_state)?;
let rhs = self.create_physical_expr(right, input_schema, ctx_state)?;
Expand Down Expand Up @@ -549,6 +575,7 @@ mod tests {
ExecutionContextState {
datasources: HashMap::new(),
scalar_functions: HashMap::new(),
var_provider: HashMap::new(),
config: ExecutionConfig::new(),
}
}
Expand Down
40 changes: 32 additions & 8 deletions rust/datafusion/src/sql/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -406,14 +406,38 @@ impl<'a, S: SchemaProvider> SqlToRel<'a, S> {
},
SQLExpr::Value(Value::SingleQuotedString(ref s)) => Ok(lit(s.clone())),

SQLExpr::Identifier(ref id) => match schema.field_with_name(&id.value) {
Ok(field) => Ok(Expr::Column(field.name().clone())),
Err(_) => Err(ExecutionError::ExecutionError(format!(
"Invalid identifier '{}' for schema {}",
id,
schema.to_string()
))),
},
SQLExpr::Identifier(ref id) => {
    // An identifier beginning with '@' becomes a scalar variable expression;
    // anything else must name a field of the input schema.
    // `starts_with` is used rather than slicing `[0..1]`, which panicked on an
    // empty identifier or one whose first character is wider than one byte.
    if id.value.starts_with('@') {
        Ok(Expr::ScalarVariable(vec![id.value.clone()]))
    } else {
        match schema.field_with_name(&id.value) {
            Ok(field) => Ok(Expr::Column(field.name().clone())),
            Err(_) => Err(ExecutionError::ExecutionError(format!(
                "Invalid identifier '{}' for schema {}",
                id,
                schema.to_string()
            ))),
        }
    }
}

SQLExpr::CompoundIdentifier(ids) => {
    // Collect the textual parts of the compound identifier in order.
    let var_names: Vec<String> = ids.iter().map(|id| id.value.clone()).collect();
    // Only compound identifiers whose first part begins with '@' are accepted
    // (as scalar variables). `first()` + `starts_with` cannot panic, unlike
    // `var_names[0][0..1]`, which failed on an empty list, an empty first part,
    // or a first character wider than one byte.
    let is_variable = var_names
        .first()
        .map_or(false, |first| first.starts_with('@'));
    if is_variable {
        Ok(Expr::ScalarVariable(var_names))
    } else {
        Err(ExecutionError::ExecutionError(format!(
            "Invalid compound identifier '{:?}' for schema {}",
            var_names,
            schema.to_string()
        )))
    }
}

SQLExpr::Wildcard => Ok(Expr::Wildcard),

Expand Down
20 changes: 20 additions & 0 deletions rust/datafusion/src/test/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

//! Common unit test utility methods

use crate::datasource::{MemTable, TableProvider};
use crate::error::Result;
use crate::execution::context::ExecutionContext;
use crate::logical_plan::{Expr, LogicalPlan, LogicalPlanBuilder};
Expand All @@ -31,6 +32,23 @@ use std::io::{BufReader, BufWriter};
use std::sync::Arc;
use tempdir::TempDir;

/// Builds an in-memory, single-row table for tests.
///
/// The table has two non-nullable columns, `id` (Int32) and `name` (Utf8),
/// and contains the single row `(1, "a")` in one partition.
///
/// # Panics
///
/// Panics if the record batch or the `MemTable` cannot be constructed.
pub fn create_table_dual() -> Box<dyn TableProvider + Send + Sync> {
    let dual_schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("name", DataType::Utf8, false),
    ]));
    let batch = RecordBatch::try_new(
        dual_schema.clone(),
        vec![
            Arc::new(array::Int32Array::from(vec![1])),
            Arc::new(array::StringArray::from(vec!["a"])),
        ],
    )
    .expect("dual record batch columns must match the dual schema");
    // The schema and batch are not used again, so both are moved rather than cloned.
    let provider = MemTable::new(dual_schema, vec![vec![batch]])
        .expect("dual batch must be accepted by MemTable");
    Box::new(provider)
}

/// Get the value of the ARROW_TEST_DATA environment variable
pub fn arrow_testdata_path() -> String {
env::var("ARROW_TEST_DATA").expect("ARROW_TEST_DATA not defined")
Expand Down Expand Up @@ -237,3 +255,5 @@ pub fn min(expr: Expr) -> Expr {
args: vec![expr],
}
}

pub mod variable;
58 changes: 58 additions & 0 deletions rust/datafusion/src/test/variable.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! System and user defined variable providers used in tests

use crate::error::Result;
use crate::logical_plan::ScalarValue;
use crate::variable::VarProvider;

/// Variable provider used by the tests for system variables.
#[derive(Debug, Default)]
pub struct SystemVar {}

impl SystemVar {
    /// Creates a new system variable provider.
    pub fn new() -> Self {
        Self {}
    }
}

impl VarProvider for SystemVar {
    /// Returns a `Utf8` scalar of the form `system-var-<names>`, where `<names>`
    /// is every element of `var_names` concatenated without a separator.
    fn get_value(&self, var_names: Vec<String>) -> Result<ScalarValue> {
        // The prefix is part of the format string, so no temporary `String` is built for it.
        let value = format!("system-var-{}", var_names.concat());
        Ok(ScalarValue::Utf8(value))
    }
}

/// Variable provider used by the tests for user defined variables.
#[derive(Debug, Default)]
pub struct UserDefinedVar {}

impl UserDefinedVar {
    /// Creates a new user defined variable provider.
    pub fn new() -> Self {
        Self {}
    }
}

impl VarProvider for UserDefinedVar {
    /// Returns a `Utf8` scalar of the form `user-defined-var-<names>`, where
    /// `<names>` is every element of `var_names` concatenated without a separator.
    fn get_value(&self, var_names: Vec<String>) -> Result<ScalarValue> {
        // The prefix is part of the format string, so no temporary `String` is built for it.
        let value = format!("user-defined-var-{}", var_names.concat());
        Ok(ScalarValue::Utf8(value))
    }
}
Loading

0 comments on commit 784c05f

Please sign in to comment.