Skip to content

Commit

Permalink
#697 Improve metadata merging method in Spark Utils.
Browse files Browse the repository at this point in the history
  • Loading branch information
yruslan committed Jul 31, 2024
1 parent 8ca815c commit 6d89cec
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,19 @@ object SparkUtils extends Logging {
*
* @param schemaFrom Schema to copy metadata from.
* @param schemaTo Schema to copy metadata to.
* @param overwrite If true, the metadata of schemaTo is not retained
* @return Same schema as schemaTo with metadata from schemaFrom.
*/
def copyMetadata(schemaFrom: StructType, schemaTo: StructType): StructType = {
def copyMetadata(schemaFrom: StructType, schemaTo: StructType, overwrite: Boolean = false): StructType = {
def joinMetadata(from: Metadata, to: Metadata): Metadata = {
val newMetadataMerged = new MetadataBuilder

newMetadataMerged.withMetadata(from)
newMetadataMerged.withMetadata(to)

newMetadataMerged.build()
}

@tailrec
def processArray(ar: ArrayType, fieldFrom: StructField, fieldTo: StructField): ArrayType = {
ar.elementType match {
Expand All @@ -267,15 +277,21 @@ object SparkUtils extends Logging {
val newFields: Array[StructField] = schemaTo.fields.map { fieldTo =>
fieldsMap.get(fieldTo.name) match {
case Some(fieldFrom) =>
val newMetadata = if (overwrite) {
fieldFrom.metadata
} else {
joinMetadata(fieldFrom.metadata, fieldTo.metadata)
}

fieldTo.dataType match {
case st: StructType if fieldFrom.dataType.isInstanceOf[StructType] =>
val newDataType = StructType(copyMetadata(fieldFrom.dataType.asInstanceOf[StructType], st).fields)
fieldTo.copy(dataType = newDataType, metadata = fieldFrom.metadata)
fieldTo.copy(dataType = newDataType, metadata = newMetadata)
case at: ArrayType =>
val newType = processArray(at, fieldFrom, fieldTo)
fieldTo.copy(dataType = newType, metadata = fieldFrom.metadata)
fieldTo.copy(dataType = newType, metadata = newMetadata)
case _ =>
fieldTo.copy(metadata = fieldFrom.metadata)
fieldTo.copy(metadata = newMetadata)
}
case None =>
fieldTo
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -603,22 +603,73 @@ class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixt
}
}

test("copyMetadata should copy metadata from one schema to another") {
test("copyMetadata should copy metadata from one schema to another when overwrite = false") {
val df1 = List(1, 2, 3).toDF("col1")
val df2 = List(1, 2, 3).toDF("col1")

val metadata1 = new MetadataBuilder()
metadata1.putString("comment", "Test")

val metadata2 = new MetadataBuilder()
metadata2.putLong("maxLength", 120)

val schema1WithMetadata = StructType(Seq(df1.schema.fields.head.copy(metadata = metadata1.build())))
val schema2WithMetadata = StructType(Seq(df2.schema.fields.head.copy(metadata = metadata2.build())))

val df1WithMetadata = spark.createDataFrame(df2.rdd, schema1WithMetadata)

val schemaWithMetadata = SparkUtils.copyMetadata(df1WithMetadata.schema, schema2WithMetadata)

val newDf = spark.createDataFrame(df2.rdd, schemaWithMetadata)

assert(newDf.schema.fields.head.metadata.getString("comment") == "Test")
assert(newDf.schema.fields.head.metadata.getLong("maxLength") == 120)
}

test("copyMetadata should not retain original metadata when overwrite = true") {
val df1 = List(1, 2, 3).toDF("col1")
val df2 = List(1, 2, 3).toDF("col1")

val metadata1 = new MetadataBuilder()
metadata1.putString("comment", "Test")

val metadata2 = new MetadataBuilder()
metadata2.putLong("maxLength", 120)

val schema1WithMetadata = StructType(Seq(df1.schema.fields.head.copy(metadata = metadata1.build())))
val schema2WithMetadata = StructType(Seq(df2.schema.fields.head.copy(metadata = metadata2.build())))

val df1WithMetadata = spark.createDataFrame(df2.rdd, schema1WithMetadata)

val schemaWithMetadata = SparkUtils.copyMetadata(df1WithMetadata.schema, df2.schema)
val schemaWithMetadata = SparkUtils.copyMetadata(df1WithMetadata.schema, schema2WithMetadata, overwrite = true)

val newDf = spark.createDataFrame(df2.rdd, schemaWithMetadata)

assert(newDf.schema.fields.head.metadata.getString("comment") == "Test")
assert(!newDf.schema.fields.head.metadata.contains("maxLength"))
}

test("Make sure flattenning does not remove metadata") {
val df1 = List(1, 2, 3).toDF("col1")
val df2 = List(1, 2, 3).toDF("col1")

val metadata1 = new MetadataBuilder()
metadata1.putString("comment", "Test")

val metadata2 = new MetadataBuilder()
metadata2.putLong("maxLength", 120)

val schema1WithMetadata = StructType(Seq(df1.schema.fields.head.copy(metadata = metadata1.build())))
val schema2WithMetadata = StructType(Seq(df2.schema.fields.head.copy(metadata = metadata2.build())))

val df1WithMetadata = spark.createDataFrame(df2.rdd, schema1WithMetadata)

val schemaWithMetadata = SparkUtils.copyMetadata(df1WithMetadata.schema, schema2WithMetadata)

val newDf = SparkUtils.unstructDataFrame(spark.createDataFrame(df2.rdd, schemaWithMetadata))

assert(newDf.schema.fields.head.metadata.getString("comment") == "Test")
assert(newDf.schema.fields.head.metadata.getLong("maxLength") == 120)
}

test("Integral to decimal conversion for complex schema") {
Expand Down

0 comments on commit 6d89cec

Please sign in to comment.