diff --git a/examples/scala/src/main/scala/example/EvolutionWithMap.scala b/examples/scala/src/main/scala/example/EvolutionWithMap.scala new file mode 100644 index 00000000000..4b6175f8a15 --- /dev/null +++ b/examples/scala/src/main/scala/example/EvolutionWithMap.scala @@ -0,0 +1,98 @@ +/* + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package example + +import org.apache.spark.sql.types._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.Row +import org.apache.spark.sql.SparkSession + +object EvolutionWithMap { + def main(args: Array[String]): Unit = { + val spark = SparkSession.builder() + .appName("EvolutionWithMap") + .master("local[*]") + .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") + .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog") + .getOrCreate() + + import spark.implicits._ + + val tableName = "insert_map_schema_evolution" + + try { + // Define initial schema + val initialSchema = StructType(Seq( + StructField("key", IntegerType, nullable = false), + StructField("metrics", MapType(StringType, StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField("value", IntegerType, nullable = false) + )))) + )) + + val data = Seq( + Row(1, Map("event" -> Row(1, 1))) + ) + + val rdd = spark.sparkContext.parallelize(data) + + val initialDf = spark.createDataFrame(rdd, initialSchema) + + initialDf.write + .option("overwriteSchema", "true") + .mode("overwrite") + .format("delta") + .saveAsTable(s"$tableName") + + // Define the evolved schema with a simultaneous change in a StructField name + // and an additional field in a map column + val evolvedSchema = StructType(Seq( + StructField("renamed_key", IntegerType, nullable = false), + StructField("metrics", MapType(StringType, StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField("value", IntegerType, nullable = false), + StructField("comment", StringType, nullable = true) + )))) + )) + + val evolvedData = Seq( + Row(1, Map("event" -> Row(1, 1, "deprecated"))) + ) + + val evolvedRDD = spark.sparkContext.parallelize(evolvedData) + + val modifiedDf = spark.createDataFrame(evolvedRDD, evolvedSchema) + + // The write below would fail without schema evolution for map types + modifiedDf.write + .mode("append") + .option("mergeSchema", "true") + .format("delta") + .insertInto(s"$tableName") + + spark.sql(s"SELECT * FROM $tableName").show(false) + + } finally { + + // Cleanup + spark.sql(s"DROP TABLE IF EXISTS $tableName") + + spark.stop() + } + + } +} diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaAnalysis.scala b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaAnalysis.scala index f20e79a58ed..098a45901c1 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaAnalysis.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaAnalysis.scala @@ -69,6 +69,7 @@ import org.apache.spark.sql.internal.SQLConf import
org.apache.spark.sql.types._ import org.apache.spark.sql.util.CaseInsensitiveStringMap + /** * Analysis rules for Delta. Currently, these rules enable schema enforcement / evolution with * INSERT INTO. @@ -913,8 +914,8 @@ class DeltaAnalysis(session: SparkSession) } private def addCastToColumn( - attr: Attribute, - targetAttr: Attribute, + attr: NamedExpression, + targetAttr: NamedExpression, tblName: String, allowTypeWidening: Boolean): NamedExpression = { val expr = (attr.dataType, targetAttr.dataType) match { @@ -930,6 +931,8 @@ class DeltaAnalysis(session: SparkSession) // Keep the type from the query, the target schema will be updated to widen the existing // type to match it. attr + case (s: MapType, t: MapType) if s != t => + addCastsToMaps(tblName, attr, s, t, allowTypeWidening) case _ => getCastFunction(attr, targetAttr.dataType, targetAttr.name) } @@ -1047,8 +1050,7 @@ class DeltaAnalysis(session: SparkSession) } /** - * Recursively casts structs in case it contains null types. - * TODO: Support other complex types like MapType and ArrayType + * Recursively casts struct data types in case the source/target type differs. */ private def addCastsToStructs( tableName: String, @@ -1067,6 +1069,8 @@ class DeltaAnalysis(session: SparkSession) val subField = Alias(GetStructField(parent, i, Option(name)), target(i).name)( explicitMetadata = Option(metadata)) addCastsToStructs(tableName, subField, nested, t, allowTypeWidening) + // We could also handle MapType within a struct here, but there is a restriction + // on deeply nested operations that may result in a maxIterations error case o => val field = parent.qualifiedName + "." + name val targetName = parent.qualifiedName + "." + target(i).name @@ -1124,6 +1128,63 @@ class DeltaAnalysis(session: SparkSession) DeltaViewHelper.stripTempViewForMerge(plan, conf) } + /** + * Recursively casts map data types in case the key/value type differs. + */ + private def addCastsToMaps( + tableName: String, + parent: NamedExpression, + sourceMapType: MapType, + targetMapType: MapType, + allowTypeWidening: Boolean): Expression = { + + val transformedKeys = + if (sourceMapType.keyType != targetMapType.keyType) { + // Create a transformation for the keys + ArrayTransform(MapKeys(parent), { + val key = NamedLambdaVariable( + "key", sourceMapType.keyType, nullable = false) + + val targetKeyAttr = AttributeReference( + "targetKey", targetMapType.keyType, nullable = false)() + val castedKey = + addCastToColumn( + key, + targetKeyAttr, + tableName, + allowTypeWidening + ) + LambdaFunction(castedKey, Seq(key)) + }) + } else { + MapKeys(parent) + } + + val transformedValues = + if (sourceMapType.valueType != targetMapType.valueType) { + // Create a transformation for the values + ArrayTransform(MapValues(parent), { + val value = NamedLambdaVariable( + "value", sourceMapType.valueType, sourceMapType.valueContainsNull) + + val targetValueAttr = AttributeReference( + "targetValue", targetMapType.valueType, sourceMapType.valueContainsNull)() + val castedValue = + addCastToColumn( + value, + targetValueAttr, + tableName, + allowTypeWidening + ) + LambdaFunction(castedValue, Seq(value)) + }) + } else { + MapValues(parent) + } + // Create new map from transformed keys and values + MapFromArrays(transformedKeys, transformedValues) + } + /** * Verify the input plan for a SINGLE streaming query with the following: * 1.
Schema location must be under checkpoint location, if not lifted by flag diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaInsertIntoTableSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaInsertIntoTableSuite.scala index 18acbc09e0f..39bb99a7213 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaInsertIntoTableSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaInsertIntoTableSuite.scala @@ -281,6 +281,219 @@ class DeltaInsertIntoSQLSuite } } + // Schema evolution for complex map type + test("insertInto schema evolution with map type - append mode: field renaming + new field") { + withTable("map_schema_evolution") { + val tableName = "map_schema_evolution" + val initialSchema = StructType(Seq( + StructField("key", IntegerType, nullable = false), + StructField("metrics", MapType(StringType, StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField("value", IntegerType, nullable = false) + )))) + )) + + val initialData = Seq( + Row(1, Map("event" -> Row(1, 1))) + ) + + val initialRdd = spark.sparkContext.parallelize(initialData) + val initialDf = spark.createDataFrame(initialRdd, initialSchema) + + // Write initial data + initialDf.write + .option("overwriteSchema", "true") + .mode("overwrite") + .format("delta") + .saveAsTable(tableName) + + // Evolved schema with field renamed and additional field in map struct + val evolvedSchema = StructType(Seq( + StructField("renamed_key", IntegerType, nullable = false), + StructField("metrics", MapType(StringType, StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField("value", IntegerType, nullable = false), + StructField("comment", StringType, nullable = true) + )))) + )) + + val evolvedData = Seq( + Row(1, Map("event" -> Row(1, 1, "deprecated"))) + ) + + val evolvedRdd = spark.sparkContext.parallelize(evolvedData) + val evolvedDf = spark.createDataFrame(evolvedRdd, evolvedSchema) + + // Insert data without schema evolution + assert(intercept[AnalysisException] { + evolvedDf.write + .mode("append") + .format("delta") + .insertInto(tableName) + }.getMessage.contains("A schema mismatch detected when writing to the Delta table") + ) + + // Insert data with schema evolution + withSQLConf("spark.databricks.delta.schema.autoMerge.enabled" -> "true") { + evolvedDf.write + .mode("append") + .format("delta") + .insertInto(tableName) + + val result = spark.sql(s"SELECT * FROM $tableName").collect() + val expected = Seq( + Row(1, Map("event" -> Row(1, 1, null))), + Row(1, Map("event" -> Row(1, 1, "deprecated"))) + ) + + assert(result.toSet == expected.toSet) + } + } + } + + test("not enough columns in source to insert in nested map types") { + withTable("source", "target") { + sql( + """CREATE TABLE source ( + | id INT, + | metrics MAP<STRING, STRUCT<id: INT, value: INT>> + |) USING delta""".stripMargin) + + sql( + """CREATE TABLE target ( + | id INT, + | metrics MAP<STRING, STRUCT<id: INT, value: INT, comment: STRING>> + |) USING delta""".stripMargin) + + sql("INSERT INTO source VALUES (1, map('event', struct(1, 1)))") + + val e = intercept[AnalysisException] { + sql("INSERT INTO target SELECT * FROM source") + } + assert(e.getMessage.contains("not enough nested fields in value")) + } + } + + + test("more columns in source to insert in nested map types") { + withTable("source", "target") { + sql( + """CREATE TABLE source ( + | id INT, + | metrics MAP<STRING, STRUCT<id: INT, value: INT, comment: STRING>> + |) USING delta""".stripMargin) + + sql( + """CREATE TABLE target ( + | id INT, + | metrics MAP<STRING, STRUCT<id: INT, value: INT>> + |) USING delta""".stripMargin) + + sql("INSERT INTO source VALUES (1, map('event',
struct(1, 1, 'deprecated')))") + + val e = intercept[AnalysisException] { + sql("INSERT INTO target SELECT * FROM source") + } + assert(e.getMessage.contains("mergeSchema")) + + withSQLConf(DeltaSQLConf.DELTA_SCHEMA_AUTO_MIGRATE.key -> "true") { + sql("INSERT INTO target SELECT * FROM source") + val result = spark.sql(s"SELECT * FROM target").collect() + val expected = Seq( + Row(1, Map("event" -> Row(1, 1, "deprecated"))) + ) + + assert(result.toSet == expected.toSet) + } + } + } + + test("more columns in source to insert in nested 2-level deep map types") { + withTable("source", "target") { + sql( + """CREATE TABLE source ( + | id INT, + | metrics MAP<STRING, MAP<STRING, STRUCT<id: INT, value: INT, comment: STRING>>> + |) USING delta""".stripMargin) + + sql( + """CREATE TABLE target ( + | id INT, + | metrics MAP<STRING, MAP<STRING, STRUCT<id: INT, value: INT>>> + |) USING delta""".stripMargin) + + sql( + """INSERT INTO source VALUES + | (1, map('event', map('subEvent', struct(1, 1, 'deprecated')))) + """.stripMargin) + + val e = intercept[AnalysisException] { + sql("INSERT INTO target SELECT * FROM source") + } + assert(e.getMessage.contains("mergeSchema")) + + withSQLConf(DeltaSQLConf.DELTA_SCHEMA_AUTO_MIGRATE.key -> "true") { + sql("INSERT INTO target SELECT * FROM source") + val result = spark.sql(s"SELECT * FROM target").collect() + val expected = Seq( + Row(1, Map("event" -> Map("subEvent" -> Row(1, 1, "deprecated")))) + ) + + assert(result.toSet == expected.toSet) + } + } + } + + + test("insert map type with different data type in key") { + withTable("source", "target") { + sql( + """CREATE TABLE source ( + | id INT, + | metrics MAP<STRING, STRUCT<id: INT, value: INT>> + |) USING delta""".stripMargin) + + sql( + """CREATE TABLE target ( + | id INT, + | metrics MAP<INT, STRUCT<id: INT, value: INT>> + |) USING delta""".stripMargin) + + sql("INSERT INTO source VALUES (1, map('1', struct(2, 3)))") + + sql("INSERT INTO target SELECT * FROM source") + + val result = spark.sql("SELECT * FROM target").collect() + val expected = Seq(Row(1, Map(1 -> Row(2, 3)))) + assert(result.toSet == expected.toSet) + } + } + + test("insert map type with different data type in value") { + withTable("source", "target") { + sql( + """CREATE TABLE source ( + | id INT, + | metrics MAP<STRING, STRUCT<id: INT, value: BIGINT>> + |) USING delta""".stripMargin) + + sql( + """CREATE TABLE target ( + | id INT, + | metrics MAP<STRING, STRUCT<id: INT, value: INT>> + |) USING delta""".stripMargin) + + sql("INSERT INTO source VALUES (1, map('m1', struct(2, 3L)))") + + sql("INSERT INTO target SELECT * FROM source") + + val result = spark.sql("SELECT * FROM target").collect() + val expected = Seq(Row(1, Map("m1" -> Row(2, 3)))) + assert(result.toSet == expected.toSet) + } + } + + def runInsertOverwrite( sourceSchema: String, sourceRecord: String, diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/typewidening/TypeWideningInsertSchemaEvolutionSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/typewidening/TypeWideningInsertSchemaEvolutionSuite.scala index 55cd149a72a..d38754cc390 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/typewidening/TypeWideningInsertSchemaEvolutionSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/typewidening/TypeWideningInsertSchemaEvolutionSuite.scala @@ -297,26 +297,7 @@ trait TypeWideningInsertSchemaEvolutionTests metadata = typeWideningMetadata(version = 1, from = ShortType, to = IntegerType))))) ) - // The next two tests document inconsistencies when handling maps. Using SQL or INSERT by position - // doesn't allow type evolution but using dataframe INSERT by name does.
- testInserts("nested struct type evolution with field upcast in map")( - initialData = TestData( - "key int, m map>", - Seq("""{ "key": 1, "m": { "a": { "x": 1, "y": 2 } } }""")), - partitionBy = Seq("key"), - overwriteWhere = "key" -> 1, - insertData = TestData( - "key int, m map>", - Seq("""{ "key": 1, "m": { "a": { "x": 3, "y": 4 } } }""")), - expectedResult = ExpectedResult.Success(new StructType() - .add("key", IntegerType) - // Type evolution wasn't applied in the map. - .add("m", MapType(StringType, new StructType() - .add("x", IntegerType) - .add("y", ShortType)))), - excludeInserts = insertsDataframe.intersect(insertsByName) - ) - + // maps now allow type evolution for INSERT by position and name in SQL and dataframe. testInserts("nested struct type evolution with field upcast in map")( initialData = TestData( "key int, m map>", @@ -333,6 +314,5 @@ trait TypeWideningInsertSchemaEvolutionTests .add("x", IntegerType) .add("y", IntegerType, nullable = true, metadata = typeWideningMetadata(version = 1, from = ShortType, to = IntegerType))))), - includeInserts = insertsDataframe.intersect(insertsByName) ) }