From 1f9bfb5be116dd3974e988692812fe3f3779a4a4 Mon Sep 17 00:00:00 2001
From: Ruslan Iushchenko
Date: Fri, 19 Apr 2024 11:47:48 +0200
Subject: [PATCH] #672 Add unit tests for Spark schema generation.

---
 .../cobrix/spark/cobol/CobolSchemaSpec.scala | 80 ++++++++++++++++++-
 1 file changed, 77 insertions(+), 3 deletions(-)

diff --git a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/CobolSchemaSpec.scala b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/CobolSchemaSpec.scala
index e4570731..0229065c 100644
--- a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/CobolSchemaSpec.scala
+++ b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/CobolSchemaSpec.scala
@@ -16,7 +16,7 @@
 
 package za.co.absa.cobrix.spark.cobol
 
-import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
+import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructType}
 import org.scalatest.wordspec.AnyWordSpec
 import org.slf4j.{Logger, LoggerFactory}
 import za.co.absa.cobrix.cobol.parser.CopybookParser
@@ -409,14 +409,14 @@ class CobolSchemaSpec extends AnyWordSpec with SimpleComparisonBase {
 
   "fromSparkOptions" should {
     "return a schema for a copybook" in {
-      val copyBook: String =
+      val copybook: String =
         """      01 RECORD.
           |        05 STR1       PIC X(10).
           |        05 STR2       PIC A(7).
           |        05 NUM3       PIC 9(7).
           |""".stripMargin
 
-      val cobolSchema = CobolSchema.fromSparkOptions(Seq(copyBook), Map.empty)
+      val cobolSchema = CobolSchema.fromSparkOptions(Seq(copybook), Map.empty)
 
       val sparkSchema = cobolSchema.getSparkSchema
 
@@ -428,6 +428,80 @@ class CobolSchemaSpec extends AnyWordSpec with SimpleComparisonBase {
       assert(sparkSchema.fields(2).name == "NUM3")
       assert(sparkSchema.fields(2).dataType == IntegerType)
     }
+
+    "return a schema for multiple copybooks" in {
+      val copybook1: String =
+        """      01 RECORD1.
+          |        05 STR1       PIC X(10).
+          |        05 STR2       PIC A(7).
+          |        05 NUM3       PIC 9(7).
+          |""".stripMargin
+
+      val copybook2: String =
+        """      01 RECORD2.
+          |        05 STR4       PIC X(10).
+          |        05 STR5       PIC A(7).
+          |        05 NUM6       PIC 9(7).
+          |""".stripMargin
+
+      val cobolSchema = CobolSchema.fromSparkOptions(Seq(copybook1, copybook2), Map("schema_retention_policy" -> "keep_original"))
+
+      val sparkSchema = cobolSchema.getSparkSchema
+
+      assert(sparkSchema.fields.length == 2)
+      assert(sparkSchema.fields.head.name == "RECORD1")
+      assert(sparkSchema.fields.head.dataType.isInstanceOf[StructType])
+      assert(sparkSchema.fields(1).name == "RECORD2")
+      assert(sparkSchema.fields(1).dataType.isInstanceOf[StructType])
+      assert(cobolSchema.getCobolSchema.ast.children.head.isRedefined)
+      assert(cobolSchema.getCobolSchema.ast.children(1).redefines.contains("RECORD1"))
+    }
+
+    "return a schema for a hierarchical copybook" in {
+      val copybook: String =
+        """      01 RECORD.
+          |        05 HEADER        PIC X(5).
+          |        05 SEGMENT-ID    PIC X(2).
+          |        05 SEG1.
+          |           10 FIELD1     PIC 9(7).
+          |        05 SEG2 REDEFINES SEG1.
+          |           10 FIELD3     PIC X(7).
+          |        05 SEG3 REDEFINES SEG1.
+          |           10 FIELD4     PIC S9(7).
+          |""".stripMargin
+
+      val cobolSchema = CobolSchema.fromSparkOptions(Seq(copybook),
+        Map(
+          "segment_field" -> "SEGMENT-ID",
+          "redefine-segment-id-map:0" -> "SEG1 => 01",
+          "redefine-segment-id-map:1" -> "SEG2 => 02",
+          "redefine-segment-id-map:2" -> "SEG3 => 03,0A",
+          "segment-children:1" -> "SEG1 => SEG2",
+          "segment-children:2" -> "SEG1 => SEG3"
+        )
+      )
+
+      val sparkSchema = cobolSchema.getSparkSchema
+
+      sparkSchema.printTreeString()
+
+      assert(sparkSchema.fields.length == 3)
+      assert(sparkSchema.fields.head.name == "HEADER")
+      assert(sparkSchema.fields.head.dataType == StringType)
+      assert(sparkSchema.fields(1).name == "SEGMENT_ID")
+      assert(sparkSchema.fields(1).dataType == StringType)
+      assert(sparkSchema.fields(2).name == "SEG1")
+      assert(sparkSchema.fields(2).dataType.isInstanceOf[StructType])
+
+      val seg1 = sparkSchema.fields(2).dataType.asInstanceOf[StructType]
+      assert(seg1.fields.length == 3)
+      assert(seg1.fields.head.name == "FIELD1")
+      assert(seg1.fields.head.dataType == IntegerType)
+      assert(seg1.fields(1).name == "SEG2")
+      assert(seg1.fields(1).dataType.isInstanceOf[ArrayType])
+      assert(seg1.fields(2).name == "SEG3")
+      assert(seg1.fields(2).dataType.isInstanceOf[ArrayType])
+    }
   }
 }