diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index f8504a487a66..b5196c086638 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -31,9 +31,9 @@ use arrow_schema::DataType; use datafusion_common::file_options::StatementOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{ - not_impl_err, plan_datafusion_err, plan_err, unqualified_field_not_found, Column, - Constraints, DFField, DFSchema, DFSchemaRef, DataFusionError, ExprSchema, - OwnedTableReference, Result, SchemaReference, TableReference, ToDFSchema, + not_impl_err, plan_datafusion_err, plan_err, unqualified_field_not_found, + Constraints, DFField, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, + Result, SchemaReference, TableReference, ToDFSchema, }; use datafusion_expr::dml::{CopyOptions, CopyTo}; use datafusion_expr::expr::Placeholder; @@ -969,12 +969,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { table_name.clone(), &arrow_schema, )?); - let values = table_schema.fields().iter().map(|f| { - ( - f.name().clone(), - ast::Expr::Identifier(ast::Ident::from(f.name().as_str())), - ) - }); // Overwrite with assignment expressions let mut planner_context = PlannerContext::new(); @@ -992,11 +986,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }) .collect::>>()?; - let values = values - .into_iter() - .map(|(k, v)| { - let val = assign_map.remove(&k).unwrap_or(v); - (k, val) + let values_and_types = table_schema + .fields() + .iter() + .map(|f| { + let col_name = f.name(); + let val = assign_map.remove(col_name).unwrap_or_else(|| { + ast::Expr::Identifier(ast::Ident::from(col_name.as_str())) + }); + (col_name, val, f.data_type()) }) .collect::>(); @@ -1026,25 +1024,22 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Projection let mut exprs = vec![]; - for (col_name, expr) in values.into_iter() { + for (col_name, expr, dt) in values_and_types.into_iter() { let expr = self.sql_to_expr(expr, &table_schema, &mut planner_context)?; let expr = match expr { datafusion_expr::Expr::Placeholder(Placeholder { ref id, ref data_type, }) => match data_type { - None => { - let dt = table_schema.data_type(&Column::from_name(&col_name))?; - datafusion_expr::Expr::Placeholder(Placeholder::new( - id.clone(), - Some(dt.clone()), - )) - } + None => datafusion_expr::Expr::Placeholder(Placeholder::new( + id.clone(), + Some(dt.clone()), + )), Some(_) => expr, }, _ => expr, }; - let expr = expr.alias(col_name); + let expr = expr.cast_to(dt, source.schema())?.alias(col_name); exprs.push(expr); } let source = project(source, exprs)?; diff --git a/datafusion/sqllogictest/test_files/update.slt b/datafusion/sqllogictest/test_files/update.slt new file mode 100644 index 000000000000..4542a262390c --- /dev/null +++ b/datafusion/sqllogictest/test_files/update.slt @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +## Update Tests +########## + +statement ok +create table t1(a int, b varchar, c double, d int); + +# Turn off the optimizer to make the logical plan closer to the initial one +statement ok +set datafusion.optimizer.max_passes = 0; + +query TT +explain update t1 set a=1, b=2, c=3.0, d=NULL; +---- +logical_plan +Dml: op=[Update] table=[t1] +--Projection: CAST(Int64(1) AS Int32) AS a, CAST(Int64(2) AS Utf8) AS b, Float64(3) AS c, CAST(NULL AS Int32) AS d +----TableScan: t1 + +query TT +explain update t1 set a=c+1, b=a, c=c+1.0, d=b; +---- +logical_plan +Dml: op=[Update] table=[t1] +--Projection: CAST(t1.c + CAST(Int64(1) AS Float64) AS Int32) AS a, CAST(t1.a AS Utf8) AS b, t1.c + Float64(1) AS c, CAST(t1.b AS Int32) AS d +----TableScan: t1