From 6348cab3e0a76b6efa9f7605ee7379a6f6378fc5 Mon Sep 17 00:00:00 2001
From: Ruslan Iushchenko
Date: Thu, 19 Dec 2024 10:17:21 +0100
Subject: [PATCH] #731 Add an option to copy data type when copying metadata.

---
 .../cobrix/spark/cobol/utils/SparkUtils.scala | 35 ++++++-
 .../spark/cobol/utils/SparkUtilsSuite.scala   | 97 +++++++++++++++++++
 2 files changed, 127 insertions(+), 5 deletions(-)

diff --git a/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala b/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala
index dd7efdbc..51c2f41e 100644
--- a/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala
+++ b/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala
@@ -48,6 +48,21 @@ object SparkUtils extends Logging {
     allExecutors.filter(!_.equals(driverHost)).toList.distinct
   }
 
+  /**
+    * Returns true if a Spark data type is a primitive data type.
+    *
+    * @param dataType A Spark data type
+    * @return true if the data type is primitive.
+    */
+  def isPrimitive(dataType: DataType): Boolean = {
+    dataType match {
+      case _: ArrayType => false
+      case _: StructType => false
+      case _: MapType => false
+      case _ => true
+    }
+  }
+
   /**
     * Given an instance of DataFrame returns a dataframe with flattened schema.
     * All nested structures are flattened and arrays are projected as columns.
@@ -248,12 +263,14 @@ object SparkUtils extends Logging {
     * @param schemaTo        Schema to copy metadata to.
     * @param overwrite       If true, the metadata of schemaTo is not retained
     * @param sourcePreferred If true, schemaFrom metadata is used on conflicts, schemaTo otherwise.
+    * @param copyDataType    If true, the data type is copied as well. This is limited to primitive data types and primitive array element types.
     * @return Same schema as schemaTo with metadata from schemaFrom.
     */
   def copyMetadata(schemaFrom: StructType,
                    schemaTo: StructType,
                    overwrite: Boolean = false,
-                   sourcePreferred: Boolean = false): StructType = {
+                   sourcePreferred: Boolean = false,
+                   copyDataType: Boolean = false): StructType = {
     def joinMetadata(from: Metadata, to: Metadata): Metadata = {
       val newMetadataMerged = new MetadataBuilder
@@ -273,12 +290,16 @@ object SparkUtils extends Logging {
       ar.elementType match {
         case st: StructType if fieldFrom.dataType.isInstanceOf[ArrayType] && fieldFrom.dataType.asInstanceOf[ArrayType].elementType.isInstanceOf[StructType] =>
           val innerStructFrom = fieldFrom.dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType]
-          val newDataType = StructType(copyMetadata(innerStructFrom, st).fields)
+          val newDataType = StructType(copyMetadata(innerStructFrom, st, overwrite, sourcePreferred, copyDataType).fields)
           ArrayType(newDataType, ar.containsNull)
         case at: ArrayType =>
           processArray(at, fieldFrom, fieldTo)
         case p =>
-          ArrayType(p, ar.containsNull)
+          if (copyDataType && fieldFrom.dataType.isInstanceOf[ArrayType] && isPrimitive(fieldFrom.dataType.asInstanceOf[ArrayType].elementType)) {
+            ArrayType(fieldFrom.dataType.asInstanceOf[ArrayType].elementType, ar.containsNull)
+          } else {
+            ArrayType(p, ar.containsNull)
+          }
       }
     }
 
@@ -295,13 +316,17 @@ object SparkUtils extends Logging {
           fieldTo.dataType match {
             case st: StructType if fieldFrom.dataType.isInstanceOf[StructType] =>
-              val newDataType = StructType(copyMetadata(fieldFrom.dataType.asInstanceOf[StructType], st).fields)
+              val newDataType = StructType(copyMetadata(fieldFrom.dataType.asInstanceOf[StructType], st, overwrite, sourcePreferred, copyDataType).fields)
               fieldTo.copy(dataType = newDataType, metadata = newMetadata)
             case at: ArrayType =>
               val newType = processArray(at, fieldFrom, fieldTo)
               fieldTo.copy(dataType = newType, metadata = newMetadata)
             case _ =>
-              fieldTo.copy(metadata = newMetadata)
+              if (copyDataType && isPrimitive(fieldFrom.dataType)) {
+                fieldTo.copy(dataType = fieldFrom.dataType, metadata = newMetadata)
+              } else {
+                fieldTo.copy(metadata = newMetadata)
+              }
           }
         case None =>
           fieldTo
       }
diff --git a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala
index 9bc0a787..51c8041b 100644
--- a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala
+++ b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala
@@ -39,6 +39,24 @@ class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixt
       """[{"id":4,"legs":[]}]""" ::
       """[{"id":5,"legs":null}]""" :: Nil
 
+  test("IsPrimitive should work as expected") {
+    assert(SparkUtils.isPrimitive(BooleanType))
+    assert(SparkUtils.isPrimitive(ByteType))
+    assert(SparkUtils.isPrimitive(ShortType))
+    assert(SparkUtils.isPrimitive(IntegerType))
+    assert(SparkUtils.isPrimitive(LongType))
+    assert(SparkUtils.isPrimitive(FloatType))
+    assert(SparkUtils.isPrimitive(DoubleType))
+    assert(SparkUtils.isPrimitive(DecimalType(10, 2)))
+    assert(SparkUtils.isPrimitive(StringType))
+    assert(SparkUtils.isPrimitive(BinaryType))
+    assert(SparkUtils.isPrimitive(DateType))
+    assert(SparkUtils.isPrimitive(TimestampType))
+    assert(!SparkUtils.isPrimitive(ArrayType(StringType)))
+    assert(!SparkUtils.isPrimitive(StructType(Seq(StructField("a", StringType)))))
+    assert(!SparkUtils.isPrimitive(MapType(StringType, StringType)))
+  }
+
   test("Test schema flattening of multiple nested structure") {
     val expectedOrigSchema =
       """root
@@ -626,6 +644,85 @@ class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixt
     assert(newDf.schema.fields.head.metadata.getLong("maxLength") == 120)
   }
 
+  test("copyMetadata should copy primitive data types when it is enabled") {
+    val schemaFrom = StructType(
+      Seq(
+        StructField("int_field1", IntegerType, nullable = true, metadata = new MetadataBuilder().putString("comment", "Test1").build()),
+        StructField("string_field", StringType, nullable = true, metadata = new MetadataBuilder().putLong("maxLength", 120).build()),
+        StructField("int_field2", StructType(
+          Seq(
+            StructField("int_field20", IntegerType, nullable = true, metadata = new MetadataBuilder().putString("comment", "Test20").build())
+          )
+        ), nullable = true),
+        StructField("struct_field2", StructType(
+          Seq(
+            StructField("int_field3", IntegerType, nullable = true, metadata = new MetadataBuilder().putString("comment", "Test3").build())
+          )
+        ), nullable = true),
+        StructField("array_string", ArrayType(StringType), nullable = true, metadata = new MetadataBuilder().putLong("maxLength", 60).build()),
+        StructField("array_struct", ArrayType(StructType(
+          Seq(
+            StructField("int_field4", IntegerType, nullable = true, metadata = new MetadataBuilder().putString("comment", "Test4").build())
+          )
+        )), nullable = true)
+      )
+    )
+
+    val schemaTo = StructType(
+      Seq(
+        StructField("int_field1", BooleanType, nullable = true),
+        StructField("string_field", IntegerType, nullable = true),
+        StructField("int_field2", IntegerType, nullable = true),
+        StructField("struct_field2", StructType(
+          Seq(
+            StructField("int_field3", BooleanType, nullable = true)
+          )
+        ), nullable = true),
+        StructField("array_string", ArrayType(IntegerType), nullable = true),
+        StructField("array_struct", ArrayType(StructType(
+          Seq(
+            StructField("int_field4", StringType, nullable = true)
+          )
+        )), nullable = true)
+      )
+    )
+
+    val schemaWithMetadata = SparkUtils.copyMetadata(schemaFrom, schemaTo, copyDataType = true)
+    val fields = schemaWithMetadata.fields
+
+    // Ensure primitive data types are copied from schemaFrom
+    // Expected schema:
+    // root
+    //  |-- int_field1: integer (nullable = true)
+    //  |-- string_field: string (nullable = true)
+    //  |-- int_field2: integer (nullable = true)
+    //  |-- struct_field2: struct (nullable = true)
+    //  |    |-- int_field3: integer (nullable = true)
+    //  |-- array_string: array (nullable = true)
+    //  |    |-- element: string (containsNull = true)
+    //  |-- array_struct: array (nullable = true)
+    //  |    |-- element: struct (containsNull = true)
+    //  |    |    |-- int_field4: integer (nullable = true)
+    assert(fields.head.dataType == IntegerType)
+    assert(fields(1).dataType == StringType)
+    assert(fields(2).dataType == IntegerType)
+    assert(fields(3).dataType.isInstanceOf[StructType])
+    assert(fields(4).dataType.isInstanceOf[ArrayType])
+    assert(fields(5).dataType.isInstanceOf[ArrayType])
+
+    assert(fields(3).dataType.asInstanceOf[StructType].fields.head.dataType == IntegerType)
+    assert(fields(4).dataType.asInstanceOf[ArrayType].elementType == StringType)
+    assert(fields(5).dataType.asInstanceOf[ArrayType].elementType.isInstanceOf[StructType])
+    assert(fields(5).dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType].fields.head.dataType == IntegerType)
+
+    // Ensure metadata is copied
+    assert(fields.head.metadata.getString("comment") == "Test1")
+    assert(fields(1).metadata.getLong("maxLength") == 120)
+    assert(fields(3).dataType.asInstanceOf[StructType].fields.head.metadata.getString("comment") == "Test3")
"Test3") + assert(fields(4).metadata.getLong("maxLength") == 60) + assert(fields(5).dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType].fields.head.metadata.getString("comment") == "Test4") + } + test("copyMetadata should retain metadata on conflicts by default") { val df1 = List(1, 2, 3).toDF("col1") val df2 = List(1, 2, 3).toDF("col1")