Skip to content

Commit

Permalink
#731 Add an option to copy data type when copying metadata.
Browse files Browse the repository at this point in the history
  • Loading branch information
yruslan committed Dec 19, 2024
1 parent f9be3c7 commit 6348cab
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,21 @@ object SparkUtils extends Logging {
allExecutors.filter(!_.equals(driverHost)).toList.distinct
}

/**
  * Returns true if a Spark data type is a primitive data type.
  *
  * A type is considered primitive when it is not a composite container type,
  * i.e. not an array, struct, or map. Everything else (numeric, string,
  * binary, boolean, date/time, decimal, etc.) counts as primitive.
  *
  * @param dataType A Spark data type.
  * @return true if the data type is primitive.
  */
def isPrimitive(dataType: DataType): Boolean = {
  dataType match {
    case _: ArrayType  => false
    case _: StructType => false
    case _: MapType    => false
    case _             => true
  }
}

/**
* Given an instance of DataFrame returns a dataframe with flattened schema.
* All nested structures are flattened and arrays are projected as columns.
Expand Down Expand Up @@ -248,12 +263,14 @@ object SparkUtils extends Logging {
* @param schemaTo Schema to copy metadata to.
* @param overwrite If true, the metadata of schemaTo is not retained
* @param sourcePreferred If true, schemaFrom metadata is used on conflicts, schemaTo otherwise.
* @param copyDataType If true, data type is copied as well. This is limited to primitive data types.
* @return Same schema as schemaTo with metadata from schemaFrom.
*/
def copyMetadata(schemaFrom: StructType,
schemaTo: StructType,
overwrite: Boolean = false,
sourcePreferred: Boolean = false): StructType = {
sourcePreferred: Boolean = false,
copyDataType: Boolean = false): StructType = {
def joinMetadata(from: Metadata, to: Metadata): Metadata = {
val newMetadataMerged = new MetadataBuilder

Expand All @@ -273,12 +290,16 @@ object SparkUtils extends Logging {
ar.elementType match {
case st: StructType if fieldFrom.dataType.isInstanceOf[ArrayType] && fieldFrom.dataType.asInstanceOf[ArrayType].elementType.isInstanceOf[StructType] =>
val innerStructFrom = fieldFrom.dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType]
val newDataType = StructType(copyMetadata(innerStructFrom, st).fields)
val newDataType = StructType(copyMetadata(innerStructFrom, st, overwrite, sourcePreferred, copyDataType).fields)
ArrayType(newDataType, ar.containsNull)
case at: ArrayType =>
processArray(at, fieldFrom, fieldTo)
case p =>
ArrayType(p, ar.containsNull)
if (copyDataType && fieldFrom.dataType.isInstanceOf[ArrayType] && isPrimitive(fieldFrom.dataType.asInstanceOf[ArrayType].elementType)) {
ArrayType(fieldFrom.dataType.asInstanceOf[ArrayType].elementType, ar.containsNull)
} else {
ArrayType(p, ar.containsNull)
}
}
}

Expand All @@ -295,13 +316,17 @@ object SparkUtils extends Logging {

fieldTo.dataType match {
case st: StructType if fieldFrom.dataType.isInstanceOf[StructType] =>
val newDataType = StructType(copyMetadata(fieldFrom.dataType.asInstanceOf[StructType], st).fields)
val newDataType = StructType(copyMetadata(fieldFrom.dataType.asInstanceOf[StructType], st, overwrite, sourcePreferred, copyDataType).fields)
fieldTo.copy(dataType = newDataType, metadata = newMetadata)
case at: ArrayType =>
val newType = processArray(at, fieldFrom, fieldTo)
fieldTo.copy(dataType = newType, metadata = newMetadata)
case _ =>
fieldTo.copy(metadata = newMetadata)
if (copyDataType && isPrimitive(fieldFrom.dataType)) {
fieldTo.copy(dataType = fieldFrom.dataType, metadata = newMetadata)
} else {
fieldTo.copy(metadata = newMetadata)
}
}
case None =>
fieldTo
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,24 @@ class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixt
"""[{"id":4,"legs":[]}]""" ::
"""[{"id":5,"legs":null}]""" :: Nil

test("IsPrimitive should work as expected") {
  // All atomic Spark types are expected to be classified as primitive.
  val primitiveTypes = Seq(
    BooleanType, ByteType, ShortType, IntegerType, LongType,
    FloatType, DoubleType, DecimalType(10, 2), StringType,
    BinaryType, DateType, TimestampType
  )

  // Composite container types (arrays, structs, maps) are not primitive.
  val compositeTypes = Seq(
    ArrayType(StringType),
    StructType(Seq(StructField("a", StringType))),
    MapType(StringType, StringType)
  )

  primitiveTypes.foreach(dt => assert(SparkUtils.isPrimitive(dt)))
  compositeTypes.foreach(dt => assert(!SparkUtils.isPrimitive(dt)))
}

test("Test schema flattening of multiple nested structure") {
val expectedOrigSchema =
"""root
Expand Down Expand Up @@ -626,6 +644,85 @@ class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixt
assert(newDf.schema.fields.head.metadata.getLong("maxLength") == 120)
}

test("copyMetadata should copy primitive data types when it is enabled") {
  // Source schema: carries both the metadata and the data types that should
  // be propagated to the target schema when copyDataType = true.
  val schemaFrom = StructType(
    Seq(
      StructField("int_field1", IntegerType, nullable = true, metadata = new MetadataBuilder().putString("comment", "Test1").build()),
      StructField("string_field", StringType, nullable = true, metadata = new MetadataBuilder().putLong("maxLength", 120).build()),
      StructField("int_field2", StructType(
        Seq(
          StructField("int_field20", IntegerType, nullable = true, metadata = new MetadataBuilder().putString("comment", "Test20").build())
        )
      ), nullable = true),
      StructField("struct_field2", StructType(
        Seq(
          StructField("int_field3", IntegerType, nullable = true, metadata = new MetadataBuilder().putString("comment", "Test3").build())
        )
      ), nullable = true),
      StructField("array_string", ArrayType(StringType), nullable = true, metadata = new MetadataBuilder().putLong("maxLength", 60).build()),
      StructField("array_struct", ArrayType(StructType(
        Seq(
          StructField("int_field4", IntegerType, nullable = true, metadata = new MetadataBuilder().putString("comment", "Test4").build())
        )
      )), nullable = true)
    )
  )

  // Target schema: deliberately uses different primitive types so the copy
  // is observable. Note 'int_field2' is a struct in schemaFrom but a
  // primitive here - a non-primitive source type must NOT be copied.
  val schemaTo = StructType(
    Seq(
      StructField("int_field1", BooleanType, nullable = true),
      StructField("string_field", IntegerType, nullable = true),
      StructField("int_field2", IntegerType, nullable = true),
      StructField("struct_field2", StructType(
        Seq(
          StructField("int_field3", BooleanType, nullable = true)
        )
      ), nullable = true),
      StructField("array_string", ArrayType(IntegerType), nullable = true),
      StructField("array_struct", ArrayType(StructType(
        Seq(
          StructField("int_field4", StringType, nullable = true)
        )
      )), nullable = true)
    )
  )

  val schemaWithMetadata = SparkUtils.copyMetadata(schemaFrom, schemaTo, copyDataType = true)
  val fields = schemaWithMetadata.fields

  // Ensure primitive data types are copied from schemaFrom.
  // Expected resulting schema (schemaTo with primitive types replaced):
  // root
  //  |-- int_field1: integer (nullable = true)        <- copied from schemaFrom
  //  |-- string_field: string (nullable = true)       <- copied from schemaFrom
  //  |-- int_field2: integer (nullable = true)        <- NOT copied (source is a struct)
  //  |-- struct_field2: struct (nullable = true)
  //  |    |-- int_field3: integer (nullable = true)   <- copied from schemaFrom
  //  |-- array_string: array (nullable = true)
  //  |    |-- element: string (containsNull = true)   <- element type copied from schemaFrom
  //  |-- array_struct: array (nullable = true)
  //  |    |-- element: struct (containsNull = true)
  //  |    |    |-- int_field4: integer (nullable = true) <- copied from schemaFrom
  assert(fields.head.dataType == IntegerType)
  assert(fields(1).dataType == StringType)
  assert(fields(2).dataType == IntegerType)
  assert(fields(3).dataType.isInstanceOf[StructType])
  assert(fields(4).dataType.isInstanceOf[ArrayType])
  assert(fields(5).dataType.isInstanceOf[ArrayType])

  assert(fields(3).dataType.asInstanceOf[StructType].fields.head.dataType == IntegerType)
  assert(fields(4).dataType.asInstanceOf[ArrayType].elementType == StringType)
  assert(fields(5).dataType.asInstanceOf[ArrayType].elementType.isInstanceOf[StructType])
  assert(fields(5).dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType].fields.head.dataType == IntegerType)

  // Ensure metadata is copied alongside the data types.
  assert(fields.head.metadata.getString("comment") == "Test1")
  assert(fields(1).metadata.getLong("maxLength") == 120)
  assert(fields(3).dataType.asInstanceOf[StructType].fields.head.metadata.getString("comment") == "Test3")
  assert(fields(4).metadata.getLong("maxLength") == 60)
  assert(fields(5).dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType].fields.head.metadata.getString("comment") == "Test4")
}

test("copyMetadata should retain metadata on conflicts by default") {
val df1 = List(1, 2, 3).toDF("col1")
val df2 = List(1, 2, 3).toDF("col1")
Expand Down

0 comments on commit 6348cab

Please sign in to comment.