Skip to content

Commit

Permalink
#466 Use schema metadata in the flatten() routine to determine array …
Browse files Browse the repository at this point in the history
…sizes.
  • Loading branch information
yruslan committed Feb 1, 2022
1 parent 5578ef6 commit 55549a8
Show file tree
Hide file tree
Showing 2 changed files with 337 additions and 103 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import org.slf4j.LoggerFactory

import scala.annotation.tailrec
import scala.collection.mutable
import scala.collection.mutable.ListBuffer
import scala.util.Try

/**
* This object contains common Spark tools used for easier processing of dataframes originated from mainframes.
Expand Down Expand Up @@ -52,7 +54,7 @@ object SparkUtils {
* Note. The method checks the maximum size for each array and that could perform slowly,
* especially on very big dataframes.
*
* @param df A dataframe
* @param df A dataframe
* @param useShortFieldNames When flattening a schema each field name will contain the full path. You can override this
* behavior and use short field names instead.
* @return A new dataframe with flat schema.
Expand Down Expand Up @@ -82,7 +84,7 @@ object SparkUtils {
* @param arrayType ArrayType
*/
def flattenStructArray(path: String, fieldNamePrefix: String, structField: StructField, arrayType: ArrayType): Unit = {
val maxInd = df.agg(max(expr(s"size($path${structField.name})"))).collect()(0)(0).toString.toInt
val maxInd = getMaxArraySize(s"$path${structField.name}")
var i = 0
while (i < maxInd) {
arrayType.elementType match {
Expand All @@ -104,7 +106,7 @@ object SparkUtils {
}

def flattenNestedArrays(path: String, fieldNamePrefix: String, arrayType: ArrayType): Unit = {
val maxInd = df.agg(max(expr(s"size($path)"))).collect()(0)(0).toString.toInt
val maxInd = getMaxArraySize(path)
var i = 0
while (i < maxInd) {
arrayType.elementType match {
Expand All @@ -125,6 +127,21 @@ object SparkUtils {
}
}

/**
  * Determines how many elements of the array column at `path` should be flattened.
  *
  * If the field can be resolved in the dataframe schema and its metadata has a
  * "maxElements" entry, that value is used. Otherwise the dataframe is aggregated
  * to find the maximum of `size(path)`.
  *
  * @param path The path to the array column.
  * @return The value of "maxElements" from the field metadata, or the aggregated maximum size,
  *         or 1 when the aggregation yields null.
  */
def getMaxArraySize(path: String): Int = {
  val sizeFromMetadata = getField(path, df.schema)
    .filter(_.metadata.contains("maxElements"))
    .map(_.metadata.getLong("maxElements").toInt)

  sizeFromMetadata.getOrElse {
    // No usable metadata - fall back to scanning the data.
    val aggregatedMax = df.agg(max(expr(s"size($path)"))).collect()(0)(0)
    // The aggregate can be null for an empty dataframe; default to 1 in that case.
    Option(aggregatedMax).map(_.toString.toInt).getOrElse(1)
  }
}

def flattenArray(path: String, fieldNamePrefix: String, structField: StructField, arrayType: ArrayType): Unit = {
arrayType.elementType match {
case _: ArrayType =>
Expand Down Expand Up @@ -221,4 +238,80 @@ object SparkUtils {
indented.replace("\r\n", "\n")
}

/**
  * Looks up a Spark field by a textual path in the given schema.
  * (originally implemented here: https://github.com/AbsaOSS/enceladus/blob/665b34fa1c04fe255729e4b6706cf9ea33227b3e/utils/src/main/scala/za/co/absa/enceladus/utils/schema/SchemaUtils.scala#L45)
  *
  * Arrays (including arrays of arrays) are traversed transparently: a path segment that follows
  * an array field is resolved against the struct that is the innermost element type of that array.
  *
  * @param path   The dot-separated path to the field
  * @param schema The schema which should contain the specified path
  * @return Some(the requested field) or None if the field does not exist
  */
def getField(path: String, schema: StructType): Option[StructField] = {
  // Strips any number of nested array wrappers and returns the innermost element type.
  @tailrec
  def unwrapArrays(dataType: DataType): DataType = dataType match {
    case ArrayType(elementType, _) => unwrapArrays(elementType)
    case other => other
  }

  // Walks down the schema one path segment at a time.
  @tailrec
  def descend(remaining: List[String], current: StructField): Option[StructField] = remaining match {
    case Nil =>
      Option(current)
    case name :: rest =>
      unwrapArrays(current.dataType) match {
        case struct: StructType => descend(rest, struct(name))
        case _ => None // a primitive cannot have child fields
      }
  }

  splitFieldPath(path) match {
    case Nil =>
      None
    case rootName :: nested =>
      // Looking up a missing field by name throws; any such failure means "field not found".
      Try(descend(nested, schema(rootName))).getOrElse(None)
  }
}

/**
  * Splits a field path into its segments.
  *
  * Segments are separated by dots. A backtick toggles quoting: the backtick itself is dropped
  * and dots between a pair of backticks are kept as part of the segment. An empty segment is
  * kept when it is followed by a dot, but an empty segment at the very end of the path is
  * dropped.
  *
  * @param path The path to split.
  * @return The list of path segments in order of appearance.
  */
private def splitFieldPath(path: String): List[String] = {
  // Accumulator: (inside backticks?, finished segments in reverse order, segment being built)
  val start = (false, List.empty[String], "")

  val (_, finishedReversed, pending) = path.foldLeft(start) {
    case ((quoted, finished, segment), ch) =>
      ch match {
        case '`' => (!quoted, finished, segment)
        case '.' if !quoted => (false, segment :: finished, "")
        case other => (quoted, finished, segment + other)
      }
  }

  // Only a non-empty trailing segment is emitted.
  val allReversed = if (pending.nonEmpty) pending :: finishedReversed else finishedReversed
  allReversed.reverse
}


}
Loading

0 comments on commit 55549a8

Please sign in to comment.