From bade69bef9adaac9e52f8ec5a2ebfad78b0969df Mon Sep 17 00:00:00 2001 From: Ruslan Iushchenko Date: Wed, 2 Feb 2022 08:23:00 +0100 Subject: [PATCH] #466 Fix flattenSchema maximum array size for inner arrays. --- .../cobrix/spark/cobol/utils/SparkUtils.scala | 9 ++++ .../spark/cobol/utils/SparkUtilsSuite.scala | 54 ++++++++++++++++++- 2 files changed, 62 insertions(+), 1 deletion(-) diff --git a/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala b/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala index 21734b09..711af35c 100644 --- a/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala +++ b/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala @@ -290,20 +290,29 @@ object SparkUtils { state match { case 0 => + // The character might be part of the path if (c == '.') { fields.append(currentField.toString()) currentField = new StringBuilder() } else if (c == '`') { state = 1 + } else if (c == '[') { + state = 2 } else { currentField.append(c) } case 1 => + // The character is part of the backquoted field name if (c == '`') { state = 0 } else { currentField.append(c) } + case 2 => + // The character is an index (that should be ignored) + if (c == ']') { + state = 0 + } } i += 1 } diff --git a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala index ffbbb931..90df02c8 100644 --- a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala +++ b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala @@ -16,7 +16,7 @@ package za.co.absa.cobrix.spark.cobol.utils -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{ArrayType, StructType} import org.scalatest.FunSuite import za.co.absa.cobrix.spark.cobol.source.base.SparkTestBase import org.slf4j.LoggerFactory 
@@ -349,6 +349,58 @@ class SparkUtilsSuite extends FunSuite with SparkTestBase with BinaryFileFixture assert(dfFlattened.count() == 0) } + test("Schema with multiple OCCURS should properly determine array sizes") { + val copyBook: String = + """ 01 RECORD. + | 02 COUNT PIC 9(1). + | 02 GROUP OCCURS 2 TIMES. + | 03 INNER-COUNT PIC 9(1). + | 03 INNER-GROUP OCCURS 3 TIMES. + | 04 FIELD PIC X. + |""".stripMargin + + val expectedFlatSchema = + """root + | |-- COUNT: integer (nullable = true) + | |-- GROUP_0_INNER_COUNT: integer (nullable = true) + | |-- INNER_GROUP_0_FIELD: string (nullable = true) + | |-- INNER_GROUP_1_FIELD: string (nullable = true) + | |-- INNER_GROUP_2_FIELD: string (nullable = true) + | |-- GROUP_1_INNER_COUNT: integer (nullable = true) + | |-- INNER_GROUP_0_FIELD1: string (nullable = true) + | |-- INNER_GROUP_1_FIELD1: string (nullable = true) + | |-- INNER_GROUP_2_FIELD1: string (nullable = true) + |""".stripMargin.replace("\r\n", "\n") + + withTempTextFile("fletten", "test", StandardCharsets.UTF_8, "") { filePath => + val df = spark.read + .format("cobol") + .option("copybook_contents", copyBook) + .option("pedantic", "true") + .option("record_format", "D") + .load(filePath) + + val metadataStruct1 = df.schema.fields(1).metadata + val metadataInnerStruct = df.schema.fields(1).dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType].fields(1).metadata + + assert(metadataStruct1.contains("minElements")) + assert(metadataInnerStruct.contains("minElements")) + assert(metadataStruct1.contains("maxElements")) + assert(metadataInnerStruct.contains("maxElements")) + + assert(metadataStruct1.getLong("minElements") == 0) + assert(metadataInnerStruct.getLong("minElements") == 0) + assert(metadataStruct1.getLong("maxElements") == 2) + assert(metadataInnerStruct.getLong("maxElements") == 3) + + val dfFlattened1 = SparkUtils.flattenSchema(df, useShortFieldNames = true) + val flatSchema1 = dfFlattened1.schema.treeString + + 
assertSchema(flatSchema1, expectedFlatSchema) + assert(dfFlattened1.count() == 0) + } + } + private def assertSchema(actualSchema: String, expectedSchema: String): Unit = { if (actualSchema != expectedSchema) { logger.error(s"EXPECTED:\n$expectedSchema")