Skip to content

Commit

Permalink
#672 Add unit tests for Spark schema generation.
Browse files Browse the repository at this point in the history
  • Loading branch information
yruslan committed Apr 22, 2024
1 parent b2a5434 commit 1f9bfb5
Showing 1 changed file with 77 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

package za.co.absa.cobrix.spark.cobol

import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructType}
import org.scalatest.wordspec.AnyWordSpec
import org.slf4j.{Logger, LoggerFactory}
import za.co.absa.cobrix.cobol.parser.CopybookParser
Expand Down Expand Up @@ -409,14 +409,14 @@ class CobolSchemaSpec extends AnyWordSpec with SimpleComparisonBase {

"fromSparkOptions" should {
"return a schema for a copybook" in {
val copyBook: String =
val copybook: String =
""" 01 RECORD.
| 05 STR1 PIC X(10).
| 05 STR2 PIC A(7).
| 05 NUM3 PIC 9(7).
|""".stripMargin

val cobolSchema = CobolSchema.fromSparkOptions(Seq(copyBook), Map.empty)
val cobolSchema = CobolSchema.fromSparkOptions(Seq(copybook), Map.empty)

val sparkSchema = cobolSchema.getSparkSchema

Expand All @@ -428,6 +428,80 @@ class CobolSchemaSpec extends AnyWordSpec with SimpleComparisonBase {
assert(sparkSchema.fields(2).name == "NUM3")
assert(sparkSchema.fields(2).dataType == IntegerType)
}

"return a schema for multiple copybooks" in {
  // Each copybook declares its own 01-level record; both are passed to the factory in one call.
  val firstCopybook: String =
    """ 01 RECORD1.
      | 05 STR1 PIC X(10).
      | 05 STR2 PIC A(7).
      | 05 NUM3 PIC 9(7).
      |""".stripMargin

  val secondCopybook: String =
    """ 01 RECORD2.
      | 05 STR4 PIC X(10).
      | 05 STR5 PIC A(7).
      | 05 NUM6 PIC 9(7).
      |""".stripMargin

  val options = Map("schema_retention_policy" -> "keep_original")
  val schema = CobolSchema.fromSparkOptions(Seq(firstCopybook, secondCopybook), options)

  val sparkSchema = schema.getSparkSchema
  val topLevel = sparkSchema.fields

  // One top-level struct per record is expected, in the order the copybooks were supplied.
  assert(topLevel.length == 2)
  assert(topLevel.head.name == "RECORD1")
  assert(topLevel.head.dataType.isInstanceOf[StructType])
  assert(topLevel(1).name == "RECORD2")
  assert(topLevel(1).dataType.isInstanceOf[StructType])

  // At the copybook AST level the second record is expected to redefine the first one.
  val records = schema.getCobolSchema.ast.children
  assert(records.head.isRedefined)
  assert(records(1).redefines.contains("RECORD1"))
}

"return a schema for a hierarchical copybook" in {
  // SEG2 and SEG3 both redefine SEG1; SEGMENT-ID is the field named by the "segment_field" option.
  val hierarchicalCopybook: String =
    """ 01 RECORD.
      | 05 HEADER PIC X(5).
      | 05 SEGMENT-ID PIC X(2).
      | 05 SEG1.
      | 10 FIELD1 PIC 9(7).
      | 05 SEG2 REDEFINES SEG1.
      | 10 FIELD3 PIC X(7).
      | 05 SEG3 REDEFINES SEG1.
      | 10 FIELD4 PIC S9(7).
      |""".stripMargin

  // Segment id mappings for the three redefines, plus SEG2 and SEG3 declared as children of SEG1.
  val options = Map(
    "segment_field" -> "SEGMENT-ID",
    "redefine-segment-id-map:0" -> "SEG1 => 01",
    "redefine-segment-id-map:1" -> "SEG2 => 02",
    "redefine-segment-id-map:2" -> "SEG3 => 03,0A",
    "segment-children:1" -> "SEG1 => SEG2",
    "segment-children:2" -> "SEG1 => SEG3"
  )

  val schema = CobolSchema.fromSparkOptions(Seq(hierarchicalCopybook), options)

  val sparkSchema = schema.getSparkSchema

  sparkSchema.printTreeString()

  // Only HEADER, SEGMENT_ID and the parent segment SEG1 are expected at the top level.
  val topLevel = sparkSchema.fields
  assert(topLevel.length == 3)
  assert(topLevel.head.name == "HEADER")
  assert(topLevel.head.dataType == StringType)
  assert(topLevel(1).name == "SEGMENT_ID")
  assert(topLevel(1).dataType == StringType)
  assert(topLevel(2).name == "SEG1")
  assert(topLevel(2).dataType.isInstanceOf[StructType])

  // The child segments are expected inside SEG1, after its own field, as array-typed columns.
  val parentSegment = topLevel(2).dataType.asInstanceOf[StructType]
  val nested = parentSegment.fields
  assert(nested.length == 3)
  assert(nested.head.name == "FIELD1")
  assert(nested.head.dataType == IntegerType)
  assert(nested(1).name == "SEG2")
  assert(nested(1).dataType.isInstanceOf[ArrayType])
  assert(nested(2).name == "SEG3")
  assert(nested(2).dataType.isInstanceOf[ArrayType])
}
}

}

0 comments on commit 1f9bfb5

Please sign in to comment.