#685 Add methods for unstructing schemas and dataframes.
This is similar to flattening, but does not flatten arrays, and it is more efficient.
yruslan committed Jun 5, 2024
1 parent 5335013 commit 0903a48
Showing 3 changed files with 238 additions and 1 deletion.
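To illustrate the idea (a hypothetical schema, not part of this commit): given a dataframe `df` with a nested struct, unstructing promotes the struct's leaf fields to top-level columns, prefixed by their parent group names, while arrays are kept as arrays.

// Hypothetical input schema of `df`:
// root
//  |-- ID: integer
//  |-- CONTACT: struct
//  |    |-- PHONE: string
//  |    |-- EMAILS: array (element: string)

val unstructed = SparkUtils.unstructDataFrame(df)

// Resulting schema:
// root
//  |-- ID: integer
//  |-- CONTACT_PHONE: string
//  |-- CONTACT_EMAILS: array (element: string)   <- the array is kept, not exploded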
@@ -428,6 +428,7 @@ object CobolParametersParser extends Logging {
    val recordLengthFieldOpt = params.get(PARAM_RECORD_LENGTH_FIELD)
    val isRecordSequence = Seq(FixedBlock, VariableLength, VariableBlock).contains(recordFormat)
    val isRecordIdGenerationEnabled = params.getOrElse(PARAM_GENERATE_RECORD_ID, "false").toBoolean
+   val isSegmentIdGenerationEnabled = params.contains(s"${PARAM_SEGMENT_ID_LEVEL_PREFIX}0")
    val fileStartOffset = params.getOrElse(PARAM_FILE_START_OFFSET, "0").toInt
    val fileEndOffset = params.getOrElse(PARAM_FILE_END_OFFSET, "0").toInt
    val varLenOccursEnabled = params.getOrElse(PARAM_VARIABLE_SIZE_OCCURS, "false").toBoolean
@@ -448,6 +449,7 @@
    if (recordLengthFieldOpt.isDefined ||
      isRecordSequence ||
      isRecordIdGenerationEnabled ||
+     isSegmentIdGenerationEnabled ||
      fileStartOffset > 0 ||
      fileEndOffset > 0 ||
      hasRecordExtractor ||
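For context, segment id generation in spark-cobol is driven by the `segment_id_level*` options; the check added above makes their presence alone route the read through the variable-length record processor. A sketch of a read that would trigger the new code path, assuming the standard spark-cobol segment-id options (the `copybook` value and the input path are placeholders):

// Hypothetical reader configuration; "segment_id_level0" holds the value of
// the segment field that identifies root-segment records.
val df = spark.read
  .format("cobol")
  .option("copybook_contents", copybook)
  .option("segment_field", "SEGMENT-ID")
  .option("segment_id_level0", "C")
  .load("/path/to/hierarchical/data")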
@@ -179,6 +179,67 @@ object SparkUtils extends Logging {
    df.select(fields.toSeq: _*)
  }

  /**
    * Removes all struct nesting when possible for a given schema.
    */
  def unstructSchema(schema: StructType, useShortFieldNames: Boolean = false): StructType = {
    def mapFieldShort(field: StructField): Array[StructField] = {
      field.dataType match {
        case st: StructType =>
          st.fields.flatMap(mapFieldShort)
        case _ =>
          Array(field)
      }
    }

    def mapFieldLong(field: StructField, path: String): Array[StructField] = {
      field.dataType match {
        case st: StructType =>
          st.fields.flatMap(f => mapFieldLong(f, s"$path${field.name}_"))
        case _ =>
          Array(field.copy(name = s"$path${field.name}"))
      }
    }

    val fields = if (useShortFieldNames)
      schema.fields.flatMap(mapFieldShort)
    else
      schema.fields.flatMap(f => mapFieldLong(f, ""))

    StructType(fields)
  }

  /**
    * Removes all struct nesting when possible for a given dataframe.
    *
    * Similar to `flattenSchema()`, but does not flatten arrays.
    */
  def unstructDataFrame(df: DataFrame, useShortFieldNames: Boolean = false): DataFrame = {
    def mapFieldShort(column: Column, field: StructField): Array[Column] = {
      field.dataType match {
        case st: StructType =>
          st.fields.flatMap(f => mapFieldShort(column.getField(f.name), f))
        case _ =>
          Array(column.as(field.name, field.metadata))
      }
    }

    def mapFieldLong(column: Column, field: StructField, path: String): Array[Column] = {
      field.dataType match {
        case st: StructType =>
          st.fields.flatMap(f => mapFieldLong(column.getField(f.name), f, s"$path${field.name}_"))
        case _ =>
          Array(column.as(s"$path${field.name}", field.metadata))
      }
    }

    val columns = if (useShortFieldNames)
      df.schema.fields.flatMap(f => mapFieldShort(col(f.name), f))
    else
      df.schema.fields.flatMap(f => mapFieldLong(col(f.name), f, ""))

    df.select(columns: _*)
  }

  /**
    * Copies metadata from one schema to another as long as names and data types are the same.
    *
@@ -237,7 +298,7 @@ object SparkUtils extends Logging {
    def mapField(column: Column, field: StructField): Column = {
      field.dataType match {
        case st: StructType =>
-         val columns = st.fields.map(f => mapField(column.getField(field.name), f))
+         val columns = st.fields.map(f => mapField(column.getField(f.name), f))
          struct(columns: _*).as(field.name, field.metadata)
        case ar: ArrayType =>
          mapArray(ar, column, field.name).as(field.name, field.metadata)
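A minimal usage sketch of the two new helpers, assuming `df` is any dataframe with nested structs (SparkUtils lives in za.co.absa.cobrix.spark.cobol.utils):

import za.co.absa.cobrix.spark.cobol.utils.SparkUtils

// Unstruct only the schema; no Spark job is executed:
val unstructedSchema = SparkUtils.unstructSchema(df.schema)

// Unstruct the dataframe itself; long names (the default) prefix each
// column with its parent group names, e.g. GROUP1_INNER_COUNT:
val unstructedDf = SparkUtils.unstructDataFrame(df)

// Short names drop the prefixes, which can yield duplicate column names:
val unstructedDfShort = SparkUtils.unstructDataFrame(df, useShortFieldNames = true)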
@@ -429,6 +429,180 @@ class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixt
    }
  }

test("unstructDataFrame() and unstructSchema() should flatten a schema and the dataframe with short names") {
val copyBook: String =
""" 01 RECORD.
| 02 COUNT PIC 9(1).
| 02 GROUP1.
| 03 INNER-COUNT PIC S9(1).
| 03 INNER-GROUP OCCURS 3 TIMES.
| 04 FIELD PIC 9.
| 02 GROUP2.
| 03 INNER-COUNT PIC S9(1).
| 03 INNER-NUM PIC 9 OCCURS 3 TIMES.
|""".stripMargin

val expectedSchema =
"""|root
| |-- COUNT: integer (nullable = true)
| |-- INNER_COUNT: integer (nullable = true)
| |-- INNER_GROUP: array (nullable = true)
| | |-- element: struct (containsNull = true)
| | | |-- FIELD: integer (nullable = true)
| |-- INNER_COUNT: integer (nullable = true)
| |-- INNER_NUM: array (nullable = true)
| | |-- element: integer (containsNull = true)
|""".stripMargin

val expectedData =
"""[ {
| "COUNT" : 2,
| "INNER_COUNT" : 1,
| "INNER_GROUP" : [ {
| "FIELD" : 4
| }, {
| "FIELD" : 5
| }, {
| "FIELD" : 6
| } ],
| "INNER_NUM" : [ 7, 8, 9 ]
|}, {
| "COUNT" : 3,
| "INNER_COUNT" : 2,
| "INNER_GROUP" : [ {
| "FIELD" : 7
| }, {
| "FIELD" : 8
| }, {
| "FIELD" : 9
| } ],
| "INNER_NUM" : [ 4, 5, 6 ]
|} ]
|""".stripMargin

withTempTextFile("flatten", "test", StandardCharsets.UTF_8, "224561789\n347892456\n") { filePath =>
val df = spark.read
.format("cobol")
.option("copybook_contents", copyBook)
.option("pedantic", "true")
.option("record_format", "D")
.option("metadata", "extended")
.load(filePath)

val actualDf = SparkUtils.unstructDataFrame(df, useShortFieldNames = true)
val actualSchema = actualDf.schema.treeString
val actualSchemaOnly = SparkUtils.unstructSchema(df.schema, useShortFieldNames = true)
val actualSchema2 = actualSchemaOnly.treeString

compareText(actualSchema, expectedSchema)
compareText(actualSchema2, expectedSchema)

val actualData = SparkUtils.prettyJSON(actualDf.orderBy("COUNT").toJSON.collect().mkString("[", ", ", "]"))

compareText(actualData, expectedData)

assert(actualDf.schema.fields.head.metadata.json.nonEmpty)
assert(actualDf.schema.fields(1).metadata.json.nonEmpty)
assert(actualDf.schema.fields(2).dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType].fields.head.metadata.json.nonEmpty)
assert(actualDf.schema.fields(3).metadata.json.nonEmpty)
assert(actualDf.schema.fields(4).metadata.json.nonEmpty)

assert(actualSchemaOnly.fields.head.metadata.json.nonEmpty)
assert(actualSchemaOnly.fields(1).metadata.json.nonEmpty)
assert(actualSchemaOnly.fields(2).dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType].fields.head.metadata.json.nonEmpty)
assert(actualSchemaOnly.fields(3).metadata.json.nonEmpty)
assert(actualSchemaOnly.fields(4).metadata.json.nonEmpty)
}
}

test("unstructDataFrame() and unstructSchema() should flatten a schema and the dataframe with long names") {
val copyBook: String =
""" 01 RECORD.
| 02 COUNT PIC 9(1).
| 02 GROUP1.
| 03 INNER-COUNT PIC S9(1).
| 03 INNER-GROUP OCCURS 3 TIMES.
| 04 FIELD PIC 9.
| 02 GROUP2.
| 03 INNER-COUNT PIC S9(1).
| 03 INNER-NUM PIC 9 OCCURS 3 TIMES.
|""".stripMargin

val expectedSchema =
"""|root
| |-- COUNT: integer (nullable = true)
| |-- GROUP1_INNER_COUNT: integer (nullable = true)
| |-- GROUP1_INNER_GROUP: array (nullable = true)
| | |-- element: struct (containsNull = true)
| | | |-- FIELD: integer (nullable = true)
| |-- GROUP2_INNER_COUNT: integer (nullable = true)
| |-- GROUP2_INNER_NUM: array (nullable = true)
| | |-- element: integer (containsNull = true)
|""".stripMargin

val expectedData =
"""[ {
| "COUNT" : 2,
| "GROUP1_INNER_COUNT" : 2,
| "GROUP1_INNER_GROUP" : [ {
| "FIELD" : 4
| }, {
| "FIELD" : 5
| }, {
| "FIELD" : 6
| } ],
| "GROUP2_INNER_COUNT" : 1,
| "GROUP2_INNER_NUM" : [ 7, 8, 9 ]
|}, {
| "COUNT" : 3,
| "GROUP1_INNER_COUNT" : 4,
| "GROUP1_INNER_GROUP" : [ {
| "FIELD" : 7
| }, {
| "FIELD" : 8
| }, {
| "FIELD" : 9
| } ],
| "GROUP2_INNER_COUNT" : 2,
| "GROUP2_INNER_NUM" : [ 4, 5, 6 ]
|} ]
|""".stripMargin

withTempTextFile("flatten", "test", StandardCharsets.UTF_8, "224561789\n347892456\n") { filePath =>
val df = spark.read
.format("cobol")
.option("copybook_contents", copyBook)
.option("pedantic", "true")
.option("record_format", "D")
.option("metadata", "extended")
.load(filePath)

val actualDf = SparkUtils.unstructDataFrame(df)
val actualSchema = actualDf.schema.treeString
val actualSchemaOnly = SparkUtils.unstructSchema(df.schema)
val actualSchema2 = actualSchemaOnly.treeString

compareText(actualSchema, expectedSchema)
compareText(actualSchema2, expectedSchema)

val actualData = SparkUtils.prettyJSON(actualDf.orderBy("COUNT").toJSON.collect().mkString("[", ", ", "]"))

compareText(actualData, expectedData)

assert(actualDf.schema.fields.head.metadata.json.nonEmpty)
assert(actualDf.schema.fields(1).metadata.json.nonEmpty)
assert(actualDf.schema.fields(2).dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType].fields.head.metadata.json.nonEmpty)
assert(actualDf.schema.fields(3).metadata.json.nonEmpty)
assert(actualDf.schema.fields(4).metadata.json.nonEmpty)

assert(actualSchemaOnly.fields.head.metadata.json.nonEmpty)
assert(actualSchemaOnly.fields(1).metadata.json.nonEmpty)
assert(actualSchemaOnly.fields(2).dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType].fields.head.metadata.json.nonEmpty)
assert(actualSchemaOnly.fields(3).metadata.json.nonEmpty)
assert(actualSchemaOnly.fields(4).metadata.json.nonEmpty)
}
}

test("Integral to decimal conversion for complex schema") {
val expectedSchema =
"""|root
