From 1417d555d45dd0d01c685c6e47ae98a7f37c8a05 Mon Sep 17 00:00:00 2001
From: Ruslan Iushchenko
Date: Wed, 31 Jul 2024 07:38:28 +0200
Subject: [PATCH] #697 Add conflict resolution logic to `SparkUtils.copyMetadata`.

---
 .../cobrix/spark/cobol/utils/SparkUtils.scala | 21 +++++---
 .../spark/cobol/utils/SparkUtilsSuite.scala   | 52 +++++++++++++++++++
 2 files changed, 67 insertions(+), 6 deletions(-)

diff --git a/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala b/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala
index 233527918..7a58d42b6 100644
--- a/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala
+++ b/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala
@@ -243,17 +243,26 @@ object SparkUtils extends Logging {
   /**
     * Copies metadata from one schema to another as long as names and data types are the same.
     *
-    * @param schemaFrom Schema to copy metadata from.
-    * @param schemaTo   Schema to copy metadata to.
-    * @param overwrite  If true, the metadata of schemaTo is not retained
+    * @param schemaFrom      Schema to copy metadata from.
+    * @param schemaTo        Schema to copy metadata to.
+    * @param overwrite       If true, the metadata of schemaTo is not retained
+    * @param sourcePreferred If true, schemaFrom metadata is used on conflicts, schemaTo otherwise.
     * @return Same schema as schemaTo with metadata from schemaFrom.
     */
-  def copyMetadata(schemaFrom: StructType, schemaTo: StructType, overwrite: Boolean = false): StructType = {
+  def copyMetadata(schemaFrom: StructType,
+                   schemaTo: StructType,
+                   overwrite: Boolean = false,
+                   sourcePreferred: Boolean = false): StructType = {
     def joinMetadata(from: Metadata, to: Metadata): Metadata = {
       val newMetadataMerged = new MetadataBuilder

-      newMetadataMerged.withMetadata(from)
-      newMetadataMerged.withMetadata(to)
+      if (sourcePreferred) {
+        newMetadataMerged.withMetadata(to)
+        newMetadataMerged.withMetadata(from)
+      } else {
+        newMetadataMerged.withMetadata(from)
+        newMetadataMerged.withMetadata(to)
+      }

       newMetadataMerged.build()
     }
diff --git a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala
index 81a192fc1..9bc0a7870 100644
--- a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala
+++ b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala
@@ -626,6 +626,58 @@ class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixt
     assert(newDf.schema.fields.head.metadata.getLong("maxLength") == 120)
   }

+  test("copyMetadata should retain metadata on conflicts by default") {
+    val df1 = List(1, 2, 3).toDF("col1")
+    val df2 = List(1, 2, 3).toDF("col1")
+
+    val metadata1 = new MetadataBuilder()
+    metadata1.putString("comment", "Test")
+    metadata1.putLong("maxLength", 100)
+
+    val metadata2 = new MetadataBuilder()
+    metadata2.putLong("maxLength", 120)
+    metadata2.putLong("newMetadata", 180)
+
+    val schema1WithMetadata = StructType(Seq(df1.schema.fields.head.copy(metadata = metadata1.build())))
+    val schema2WithMetadata = StructType(Seq(df2.schema.fields.head.copy(metadata = metadata2.build())))
+
+    val df1WithMetadata = spark.createDataFrame(df2.rdd, schema1WithMetadata)
+
+    val schemaWithMetadata = SparkUtils.copyMetadata(df1WithMetadata.schema, schema2WithMetadata)
+
+    val newDf = spark.createDataFrame(df2.rdd, schemaWithMetadata)
+
+    assert(newDf.schema.fields.head.metadata.getString("comment") == "Test")
+    assert(newDf.schema.fields.head.metadata.getLong("maxLength") == 120)
+    assert(newDf.schema.fields.head.metadata.getLong("newMetadata") == 180)
+  }
+
+  test("copyMetadata should overwrite metadata on conflicts when sourcePreferred=true") {
+    val df1 = List(1, 2, 3).toDF("col1")
+    val df2 = List(1, 2, 3).toDF("col1")
+
+    val metadata1 = new MetadataBuilder()
+    metadata1.putString("comment", "Test")
+    metadata1.putLong("maxLength", 100)
+
+    val metadata2 = new MetadataBuilder()
+    metadata2.putLong("maxLength", 120)
+    metadata2.putLong("newMetadata", 180)
+
+    val schema1WithMetadata = StructType(Seq(df1.schema.fields.head.copy(metadata = metadata1.build())))
+    val schema2WithMetadata = StructType(Seq(df2.schema.fields.head.copy(metadata = metadata2.build())))
+
+    val df1WithMetadata = spark.createDataFrame(df2.rdd, schema1WithMetadata)
+
+    val schemaWithMetadata = SparkUtils.copyMetadata(df1WithMetadata.schema, schema2WithMetadata, sourcePreferred = true)
+
+    val newDf = spark.createDataFrame(df2.rdd, schemaWithMetadata)
+
+    assert(newDf.schema.fields.head.metadata.getString("comment") == "Test")
+    assert(newDf.schema.fields.head.metadata.getLong("maxLength") == 100)
+    assert(newDf.schema.fields.head.metadata.getLong("newMetadata") == 180)
+  }
+
   test("copyMetadata should not retain original metadata when overwrite = true") {
     val df1 = List(1, 2, 3).toDF("col1")
     val df2 = List(1, 2, 3).toDF("col1")
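Below is a minimal usage sketch of the merge behaviour this patch introduces; it is illustrative and not part of the diff. The `copyMetadata` signature and the `sourcePreferred` flag come from the patch above, while the object name `CopyMetadataExample`, the column name `col1`, and the `maxLength` metadata key are made up for the example.

// Illustrative sketch only: shows how conflicting metadata keys are resolved
// with and without sourcePreferred, mirroring the unit tests in this patch.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{MetadataBuilder, StructType}
import za.co.absa.cobrix.spark.cobol.utils.SparkUtils

object CopyMetadataExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("copyMetadata example").getOrCreate()
    import spark.implicits._

    val base = List(1, 2, 3).toDF("col1")

    // Source schema carries maxLength = 100, target schema carries maxLength = 120.
    val fromMeta = new MetadataBuilder().putLong("maxLength", 100).build()
    val toMeta   = new MetadataBuilder().putLong("maxLength", 120).build()

    val schemaFrom = StructType(Seq(base.schema.head.copy(metadata = fromMeta)))
    val schemaTo   = StructType(Seq(base.schema.head.copy(metadata = toMeta)))

    // Default behaviour: the target schema's value wins on conflict -> 120.
    val kept = SparkUtils.copyMetadata(schemaFrom, schemaTo)

    // sourcePreferred = true: the source schema's value wins on conflict -> 100.
    val overridden = SparkUtils.copyMetadata(schemaFrom, schemaTo, sourcePreferred = true)

    println(kept.head.metadata.getLong("maxLength"))       // 120
    println(overridden.head.metadata.getLong("maxLength")) // 100

    spark.stop()
  }
}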