Skip to content

Commit

Permalink
#97 ControlMeasureBuilder.withAggregateColumn(s) implementations. Todo needs testing, cleanup & documentation update
Browse files Browse the repository at this point in the history

MeasurementProcessor split into an object and a class so that its generic processing methods can be reused.
  • Loading branch information
dk1844 committed Aug 3, 2021
1 parent f526eec commit fbe44f2
Show file tree
Hide file tree
Showing 4 changed files with 246 additions and 132 deletions.
10 changes: 7 additions & 3 deletions atum/src/main/scala/za/co/absa/atum/core/ControlType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,16 @@ object ControlType {
case object AbsAggregatedTotal extends ControlType("absAggregatedTotal")
case object HashCrc32 extends ControlType("hashCrc32")

val values = Seq(Count.value, DistinctCount.value, AggregatedTotal.value, AbsAggregatedTotal.value, HashCrc32.value)
val values = Seq(Count, DistinctCount, AggregatedTotal, AbsAggregatedTotal, HashCrc32)
val valueStrings = values.map(_.value)

def getNormalizedValue(input: String): String = {
values.find(value => isControlMeasureTypeEqual(input, value)).getOrElse(input)
/** Returns the canonical string form of a control-type name, or `input` unchanged when no known type matches. */
def getNormalizedStringValue(input: String): String =
  valueStrings
    .find(isControlMeasureTypeEqual(input, _))
    .getOrElse(input)

/**
 * Looks up the [[ControlType]] whose string value is exactly `s`.
 *
 * @param s string value of a control type (e.g. "count")
 * @return the matching [[ControlType]] case object
 * @throws NoSuchElementException when `s` does not name any known control type
 */
def withValueName(s: String): ControlType = values.find(_.value == s).getOrElse(
  // `.value` is already a String, so compare directly (the original `.value.toString` was redundant)
  throw new NoSuchElementException(s"No value found for '$s'. Allowed values are: $valueStrings"))

def isControlMeasureTypeEqual(x: String, y: String): Boolean = {
if (x.toLowerCase == y.toLowerCase) {
true
Expand Down
168 changes: 96 additions & 72 deletions atum/src/main/scala/za/co/absa/atum/core/MeasurementProcessor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,109 @@ package za.co.absa.atum.core

import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DecimalType, LongType, StringType}
import org.apache.spark.sql.{Column, Dataset, Row}
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
import za.co.absa.atum.core.ControlType._
import za.co.absa.atum.core.MeasurementProcessor.MeasurementFunction
import za.co.absa.atum.model.Measurement
import za.co.absa.atum.utils.controlmeasure.ControlMeasureUtils

import scala.util.{Failure, Success, Try}

object MeasurementProcessor {
  /** A measurement function computes a control measure over a DataFrame and renders it as a string. */
  type MeasurementFunction = DataFrame => String

  // Name of the temporary column that aggregation expressions operate on
  private val valueColumnName: String = "value"

  /**
   * Assembles the measurement function for `controlCol` based on `controlType`.
   *
   * @param controlCol  name of the column the control measure is computed over
   * @param controlType kind of control measure (count, distinct count, aggregated total,
   *                    absolute aggregated total, or CRC32 hash)
   * @return a function that computes the requested measure of `controlCol` for a given
   *         DataFrame and returns it as a string
   */
  private[atum] def getMeasurementFunction(controlCol: String, controlType: ControlType): MeasurementFunction = {
    controlType match {
      case Count => (ds: Dataset[Row]) => ds.count().toString
      case DistinctCount => (ds: Dataset[Row]) => {
        ds.select(col(controlCol)).distinct().count().toString
      }
      case AggregatedTotal =>
        (ds: Dataset[Row]) => {
          val aggCol = sum(col(valueColumnName))
          aggregateColumn(ds, controlCol, aggCol)
        }
      case AbsAggregatedTotal =>
        (ds: Dataset[Row]) => {
          val aggCol = sum(abs(col(valueColumnName)))
          aggregateColumn(ds, controlCol, aggCol)
        }
      case HashCrc32 =>
        (ds: Dataset[Row]) => {
          // Hash each value with CRC32, then sum the hashes into a single control value
          val aggColName = ControlMeasureUtils.getTemporaryColumnName(ds)
          val v = ds.withColumn(aggColName, crc32(col(controlCol).cast("String")))
            .agg(sum(col(aggColName))).collect()(0)(0)
          // sum() over an empty dataset yields null — represented here as an empty string
          if (v == null) "" else v.toString
        }
    }
  }

  /**
   * Aggregates `measureColumn` of `ds` with `aggExpression`, coping with type-specific pitfalls
   * (long overflow, string-typed numeric columns, BigDecimal serialization).
   *
   * @param measureColumn name of the column to aggregate
   * @param aggExpression aggregation expression applied to the `value` column (e.g. `sum(col("value"))`)
   * @return the aggregated value rendered as a plain (non-scientific-notation) string
   */
  private def aggregateColumn(ds: Dataset[Row], measureColumn: String, aggExpression: Column): String = {
    val dataType = ds.select(measureColumn).schema.fields(0).dataType
    val aggregatedValue = dataType match {
      case _: LongType =>
        // This is protection against long overflow, e.g. Long.MaxValue = 9223372036854775807:
        // scala> sc.parallelize(List(Long.MaxValue, 1)).toDF.agg(sum("value")).take(1)(0)(0)
        // res11: Any = -9223372036854775808
        // Converting to BigDecimal fixes the issue
        val ds2 = ds.select(col(measureColumn).cast(DecimalType(38, 0)).as(valueColumnName))
        val collected = ds2.agg(aggExpression).collect()(0)(0)
        if (collected == null) 0 else collected
      case _: StringType =>
        // Support for string type aggregation: cast to decimal, default to 0 on empty input.
        // The BigDecimal is passed through as-is; workaroundBigDecimalIssues below performs the
        // stripTrailingZeros/toPlainString rendering (previously duplicated here).
        val ds2 = ds.select(col(measureColumn).cast(DecimalType(38, 18)).as(valueColumnName))
        val collected = ds2.agg(aggExpression).collect()(0)(0)
        if (collected == null) new java.math.BigDecimal(0) else collected.asInstanceOf[java.math.BigDecimal]
      case _ =>
        val ds2 = ds.select(col(measureColumn).as(valueColumnName))
        val collected = ds2.agg(aggExpression).collect()(0)(0)
        if (collected == null) 0 else collected
    }
    // check if total is required to be presented as larger type - big decimal
    workaroundBigDecimalIssues(aggregatedValue)
  }

  /**
   * Renders an aggregated value as a string, normalizing BigDecimal representations so that
   * different serializers do not produce different JSONs for the same number.
   */
  private def workaroundBigDecimalIssues(value: Any): String = {
    // If aggregated value is java.math.BigDecimal, convert it to scala.math.BigDecimal
    value match {
      case v: java.math.BigDecimal =>
        // Convert the value to string to workaround different serializers generate different JSONs for BigDecimal
        v.stripTrailingZeros // removes trailing zeros (2001.500000 -> 2001.5), but can introduce scientific notation (600.000 -> 6E+2)
          .toPlainString // converts to normal string (6E+2 -> "600")
      case v: BigDecimal =>
        // Convert the value to string to workaround different serializers generate different JSONs for BigDecimal
        new java.math.BigDecimal(v.toString())
          .stripTrailingZeros // removes trailing zeros (2001.500000 -> 2001.5), but can introduce scientific notation (600.000 -> 6E+2)
          .toPlainString // converts to normal string (6E+2 -> "600")
      case a => a.toString
    }
  }

}

/**
* This class is used for processing Spark Dataset to calculate aggregates / control measures
*/
class MeasurementProcessor(private var measurements: Seq[Measurement]) {
type MeasurementFunction = Dataset[Row] => String
type MeasurementProcessor = (Measurement, MeasurementFunction)

// Assigning measurement function to each measurement
var processors: Seq[MeasurementProcessor] =
measurements.map(m => (m, getMeasurementFunction(m)))

private val valueColumnName: String = "value"

/** The method calculates measurements for each control. */
private[atum] def measureDataset(ds: Dataset[Row]): Seq[Measurement] = {
Expand Down Expand Up @@ -72,80 +158,18 @@ class MeasurementProcessor(private var measurements: Seq[Measurement]) {

/** The method maps string representation of control type to measurement function. */
private def getMeasurementFunction(measurement: Measurement): MeasurementFunction = {

measurement.controlType match {
case Count.value => (ds: Dataset[Row]) => ds.count().toString
case DistinctCount.value => (ds: Dataset[Row]) => {
ds.select(col(measurement.controlCol)).distinct().count().toString
}
case AggregatedTotal.value =>
(ds: Dataset[Row]) => {
val aggCol = sum(col(valueColumnName))
aggregateColumn(ds, measurement.controlCol, aggCol)
}
case AbsAggregatedTotal.value =>
(ds: Dataset[Row]) => {
val aggCol = sum(abs(col(valueColumnName)))
aggregateColumn(ds, measurement.controlCol, aggCol)
}
case HashCrc32.value =>
(ds: Dataset[Row]) => {
val aggColName = ControlMeasureUtils.getTemporaryColumnName(ds)
val v = ds.withColumn(aggColName, crc32(col(measurement.controlCol).cast("String")))
.agg(sum(col(aggColName))).collect()(0)(0)
if (v == null) "" else v.toString
}
case _ =>
Try {
ControlType.withValueName(measurement.controlType)
} match {
case Failure(exception) =>
Atum.log.error(s"Unrecognized control measurement type '${measurement.controlType}'. Available control measurement types are: " +
s"${ControlType.values.mkString(",")}.")
Atum.log.error(exception.getLocalizedMessage)
(_: Dataset[Row]) => "N/A"
}
}

private def workaroundBigDecimalIssues(value: Any): String = {
// If aggregated value is java.math.BigDecimal, convert it to scala.math.BigDecimal
value match {
case v: java.math.BigDecimal =>
// Convert the value to string to workaround different serializers generate different JSONs for BigDecimal
v.stripTrailingZeros // removes trailing zeros (2001.500000 -> 2001.5, but can introduce scientific notation (600.000 -> 6E+2)
.toPlainString // converts to normal string (6E+2 -> "600")
case v: BigDecimal =>
// Convert the value to string to workaround different serializers generate different JSONs for BigDecimal
new java.math.BigDecimal(v.toString())
.stripTrailingZeros // removes trailing zeros (2001.500000 -> 2001.5, but can introduce scientific notation (600.000 -> 6E+2)
.toPlainString // converts to normal string (6E+2 -> "600")
case a => a.toString
case Success(controlType) =>
MeasurementProcessor.getMeasurementFunction(measurement.controlCol, controlType)
}
}

private def aggregateColumn(ds: Dataset[Row], measureColumn: String, aggExpression: Column): String = {
val dataType = ds.select(measureColumn).schema.fields(0).dataType
val aggregatedValue = dataType match {
case _: LongType =>
// This is protection against long overflow, e.g. Long.MaxValue = 9223372036854775807:
// scala> sc.parallelize(List(Long.MaxValue, 1)).toDF.agg(sum("value")).take(1)(0)(0)
// res11: Any = -9223372036854775808
// Converting to BigDecimal fixes the issue
//val ds2 = ds.select(col(measurement.controlCol).cast(DecimalType(38, 0)).as("value"))
//ds2.agg(sum(abs($"value"))).collect()(0)(0)
val ds2 = ds.select(col(measureColumn).cast(DecimalType(38, 0)).as(valueColumnName))
val collected = ds2.agg(aggExpression).collect()(0)(0)
if (collected == null) 0 else collected
case _: StringType =>
// Support for string type aggregation
val ds2 = ds.select(col(measureColumn).cast(DecimalType(38, 18)).as(valueColumnName))
val collected = ds2.agg(aggExpression).collect()(0)(0)
val value = if (collected==null) new java.math.BigDecimal(0) else collected.asInstanceOf[java.math.BigDecimal]
value
.stripTrailingZeros // removes trailing zeros (2001.500000 -> 2001.5, but can introduce scientific notation (600.000 -> 6E+2)
.toPlainString // converts to normal string (6E+2 -> "600")
case _ =>
val ds2 = ds.select(col(measureColumn).as(valueColumnName))
val collected = ds2.agg(aggExpression).collect()(0)(0)
if (collected == null) 0 else collected
}
//check if total is required to be presented as larger type - big decimal
workaroundBigDecimalIssues(aggregatedValue)
}

}
Loading

0 comments on commit fbe44f2

Please sign in to comment.