#97 Aggregate control type strategy (#107)
* #97 AggregateControlTypeStrategy suggested API for ControlMeasureBuilder usage

* #97 ControlMeasureBuilder.withAggregateColumn(s) implementations. Todo needs testing, cleanup & documentation update
MeasurementProcessor split into object/class to offer generic processing methods to be reusable.

* #97 ControlMeasureBuilder.withAggregateColumn(s) unit tests added (regression guard)

* #97 ControlMeasureBuilder.withAggregateColumn(s) in README.md, original only-default `cmBuilder.calculateMeasurement` removed
dk1844 authored Aug 12, 2021
1 parent a1d99db commit 7362a73
Showing 7 changed files with 361 additions and 177 deletions.
47 changes: 32 additions & 15 deletions README.md
@@ -111,40 +111,57 @@ serve as a reference implementation for calculating control measurements.

#### Obtaining a ControlMeasure
The builder instance obtained by `ControlMeasureBuilder.forDf()` accepts some metadata via optional setters.
In addition it accepts the list of fields for which control measurements should be generated. Depending on the data type
of a field the method will generate a different control measurement. For numeric types it will generate
**controlType.absAggregatedTotal**, e.g. **SUM(ABS(X))**. For non-numeric types it will generate
**controlType.HashCrc32** e.g. **SUM(CRC32(x))**. Non-primitive data types are not supported.
In addition, it accepts the list of fields for which control measurements should be generated. These columns can be
defined in several ways, and the type of measurement computed for each of them can also be configured:

```scala
import za.co.absa.atum.utils.controlmeasure.ControlMeasureBuilder
import ControlMeasureBuilder.ControlTypeStrategy.{Default, Specific}
import za.co.absa.atum.core.ControlType.{Count, DistinctCount, AggregatedTotal, AbsAggregatedTotal, HashCrc32}

// controlMeasureBuilder obtainable by ControlMeasureBuilder.forDf(df)

// with Default, the ControlType will be chosen based on the field type (AbsAggregatedTotal for numeric, HashCrc32 otherwise)
val updatedBuilder1 = controlMeasureBuilder.withAggregateColumns(Seq("col1", "col2")) // equivalent to .withAggregateColumns(Seq("col1", "col2"), Default)
val updatedBuilder2 = controlMeasureBuilder.withAggregateColumns(Seq("col1", "col2"), Specific(HashCrc32)) // all columns will use HashCrc32

val iterativelyUpdatedBuilder3 = controlMeasureBuilder
.withAggregateColumn("col1") // equivalent to .withAggregateColumn("col1", Default). AbsAggregatedTotal used if col1 is numeric, HashCrc32 otherwise
.withAggregateColumn("col2", Specific(DistinctCount)) // DistinctCount for this column's measurement

```
The above excerpt demonstrates two ways to define the aggregate columns. `.withAggregateColumns` sets them all at once
with a single `ControlType` strategy for every column (either `Default` or one common `Specific` `ControlType`);
subsequent calls replace the columns already defined. The more fine-grained `.withAggregateColumn` lets the control
type strategy be specified for each column individually; subsequent calls add to the columns already defined.

The `Default` control type strategy selects `AbsAggregatedTotal` (**SUM(ABS(X))**) for numeric fields and
`HashCrc32` (**SUM(CRC32(x))**) for non-numeric ones. Non-primitive data types are not supported.
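
The selection rule described above can be illustrated with a minimal, Spark-free sketch. The types below are simplified stand-ins written for this example only, and `resolve` with its `isNumeric` flag is a hypothetical helper standing in for the inspection of the column's data type; none of this is the library's actual implementation.

```scala
// Simplified stand-ins for the library types; assumptions made for illustration only.
sealed abstract class ControlType(val value: String)
object ControlType {
  case object DistinctCount extends ControlType("distinctCount")
  case object AbsAggregatedTotal extends ControlType("absAggregatedTotal")
  case object HashCrc32 extends ControlType("hashCrc32")
}

sealed trait ControlTypeStrategy
object ControlTypeStrategy {
  case object Default extends ControlTypeStrategy
  final case class Specific(controlType: ControlType) extends ControlTypeStrategy
}

object StrategySketch {
  import ControlType._
  import ControlTypeStrategy._

  // Hypothetical helper: `isNumeric` stands in for checking the field's data type.
  def resolve(strategy: ControlTypeStrategy, isNumeric: Boolean): ControlType = strategy match {
    case Default               => if (isNumeric) AbsAggregatedTotal else HashCrc32
    case Specific(controlType) => controlType // the caller's choice is used as given
  }
}
```

With this sketch, `StrategySketch.resolve(ControlTypeStrategy.Default, isNumeric = false)` yields `HashCrc32`, while a `Specific` strategy returns whatever control type it wraps regardless of the field type. How the library treats a numeric-only control type requested for a non-numeric column is not covered here.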

A full example of initial control measure generation then could look as follows:
```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import za.co.absa.atum.model.ControlMeasure
import za.co.absa.atum.utils.controlmeasure.ControlMeasureBuilder

val dataSourceName = "Source Application"
val inputPath = "/path/to/source"
val batchDate = "15-10-2017"
val batchVersion = 1

val spark = SparkSession.builder()
.appName("An info file creation job")
.getOrCreate()

val df: DataFrame = spark
.read
.format("csv").option("header", "true") // adjust to your data source format
.load(inputPath)
.load("path/to/source")
val aggregateColumns = List("employeeId", "address", "dealId") // these columns must exist in the `df`

// builder-like fluent API to construct a ControlMeasureBuilder and yield the `controlMeasure` with `build`
val controlMeasure: ControlMeasure =
ControlMeasureBuilder.forDf(df)
.withAggregateColumns(aggregateColumns)
.withInputPath(inputPath)
.withSourceApplication(dataSourceName)
.withReportDate(batchDate)
.withReportVersion(batchVersion)
.withAggregateColumns(aggregateColumns) // using Default controlType strategy: AbsAggregatedTotal for numeric fields, HashCrc32 otherwise
.withInputPath("path/to/source")
.withSourceApplication("Source Application")
.withReportDate("15-10-2017")
.withReportVersion(1)
.build

// convert to JSON using .asJson | asJsonPretty
22 changes: 13 additions & 9 deletions atum/src/main/scala/za/co/absa/atum/core/ControlType.scala
@@ -15,20 +15,24 @@

package za.co.absa.atum.core

class ControlType(val value: String)
class ControlType(val value: String, val onlyForNumeric: Boolean)
object ControlType {
case object Count extends ControlType("count")
case object DistinctCount extends ControlType("distinctCount")
case object AggregatedTotal extends ControlType("aggregatedTotal")
case object AbsAggregatedTotal extends ControlType("absAggregatedTotal")
case object HashCrc32 extends ControlType("hashCrc32")
case object Count extends ControlType("count", false)
case object DistinctCount extends ControlType("distinctCount", false)
case object AggregatedTotal extends ControlType("aggregatedTotal", true)
case object AbsAggregatedTotal extends ControlType("absAggregatedTotal", true)
case object HashCrc32 extends ControlType("hashCrc32", false)

val values = Seq(Count.value, DistinctCount.value, AggregatedTotal.value, AbsAggregatedTotal.value, HashCrc32.value)
val values: Seq[ControlType] = Seq(Count, DistinctCount, AggregatedTotal, AbsAggregatedTotal, HashCrc32)
val valueNames: Seq[String] = values.map(_.value)

def getNormalizedValue(input: String) = {
values.find(value => isControlMeasureTypeEqual(input, value)).getOrElse(input)
def getNormalizedValueName(input: String): String = {
valueNames.find(value => isControlMeasureTypeEqual(input, value)).getOrElse(input)
}

def withValueName(s: String): ControlType = values.find(_.value.toString == s).getOrElse(
throw new NoSuchElementException(s"No value found for '$s'. Allowed values are: $valueNames"))

def isControlMeasureTypeEqual(x: String, y: String): Boolean = {
if (x.toLowerCase == y.toLowerCase) {
true
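
The name-based lookup added in this hunk can be tried out in isolation. The following is a minimal sketch that copies the shape of `ControlType` from the diff into a standalone stand-in (marked `sealed` here for the example) and wraps `withValueName` in `Try`; `LookupSketch.describe` is a hypothetical helper, not part of the library.

```scala
import scala.util.{Failure, Success, Try}

// Simplified stand-in mirroring the ControlType shown in the diff; not the library class itself.
sealed abstract class ControlType(val value: String, val onlyForNumeric: Boolean)

object ControlType {
  case object Count extends ControlType("count", false)
  case object DistinctCount extends ControlType("distinctCount", false)
  case object AggregatedTotal extends ControlType("aggregatedTotal", true)
  case object AbsAggregatedTotal extends ControlType("absAggregatedTotal", true)
  case object HashCrc32 extends ControlType("hashCrc32", false)

  val values: Seq[ControlType] = Seq(Count, DistinctCount, AggregatedTotal, AbsAggregatedTotal, HashCrc32)
  val valueNames: Seq[String] = values.map(_.value)

  // Exact-match lookup by name; unknown names throw, as in the diff.
  def withValueName(s: String): ControlType = values.find(_.value == s).getOrElse(
    throw new NoSuchElementException(s"No value found for '$s'. Allowed values are: $valueNames"))
}

object LookupSketch {
  // Hypothetical helper: describes the lookup outcome instead of letting the exception escape.
  def describe(name: String): String = Try(ControlType.withValueName(name)) match {
    case Success(ct) => s"${ct.value} (numeric only: ${ct.onlyForNumeric})"
    case Failure(_)  => "N/A"
  }
}
```

Because `withValueName` compares names exactly, `LookupSketch.describe("hashCrc32")` succeeds while `LookupSketch.describe("HASHCRC32")` falls back to `"N/A"`; in this sketch, any case normalization would have to happen before the lookup.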
169 changes: 97 additions & 72 deletions atum/src/main/scala/za/co/absa/atum/core/MeasurementProcessor.scala
@@ -17,23 +17,110 @@ package za.co.absa.atum.core

import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DecimalType, LongType, StringType}
import org.apache.spark.sql.{Column, Dataset, Row}
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
import za.co.absa.atum.core.ControlType._
import za.co.absa.atum.core.MeasurementProcessor.MeasurementFunction
import za.co.absa.atum.model.Measurement
import za.co.absa.atum.utils.controlmeasure.ControlMeasureUtils

import scala.util.{Failure, Success, Try}

object MeasurementProcessor {
type MeasurementFunction = DataFrame => String

private val valueColumnName: String = "value"

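/**
* Assembles the measurement function for `controlCol` based on the `controlType`.
* @param controlCol name of the column to be measured
* @param controlType type of the control measurement to compute
* @return function that computes the measurement of a DataFrame as a String
*/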
private[atum] def getMeasurementFunction(controlCol: String, controlType: ControlType): MeasurementFunction = {
controlType match {
case Count => (ds: Dataset[Row]) =>
ds.select(col(controlCol)).count().toString
case DistinctCount => (ds: Dataset[Row]) => {
ds.select(col(controlCol)).distinct().count().toString
}
case AggregatedTotal =>
(ds: Dataset[Row]) => {
val aggCol = sum(col(valueColumnName))
aggregateColumn(ds, controlCol, aggCol)
}
case AbsAggregatedTotal =>
(ds: Dataset[Row]) => {
val aggCol = sum(abs(col(valueColumnName)))
aggregateColumn(ds, controlCol, aggCol)
}
case HashCrc32 =>
(ds: Dataset[Row]) => {
val aggColName = ControlMeasureUtils.getTemporaryColumnName(ds)
val v = ds.withColumn(aggColName, crc32(col(controlCol).cast("String")))
.agg(sum(col(aggColName))).collect()(0)(0)
if (v == null) "" else v.toString
}
}
}

private def aggregateColumn(ds: Dataset[Row], measureColumn: String, aggExpression: Column): String = {
val dataType = ds.select(measureColumn).schema.fields(0).dataType
val aggregatedValue = dataType match {
case _: LongType =>
// This is protection against long overflow, e.g. Long.MaxValue = 9223372036854775807:
// scala> sc.parallelize(List(Long.MaxValue, 1)).toDF.agg(sum("value")).take(1)(0)(0)
// res11: Any = -9223372036854775808
// Converting to BigDecimal fixes the issue
//val ds2 = ds.select(col(measurement.controlCol).cast(DecimalType(38, 0)).as("value"))
//ds2.agg(sum(abs($"value"))).collect()(0)(0)
val ds2 = ds.select(col(measureColumn).cast(DecimalType(38, 0)).as(valueColumnName))
val collected = ds2.agg(aggExpression).collect()(0)(0)
if (collected == null) 0 else collected
case _: StringType =>
// Support for string type aggregation
val ds2 = ds.select(col(measureColumn).cast(DecimalType(38, 18)).as(valueColumnName))
val collected = ds2.agg(aggExpression).collect()(0)(0)
val value = if (collected==null) new java.math.BigDecimal(0) else collected.asInstanceOf[java.math.BigDecimal]
value
.stripTrailingZeros // removes trailing zeros (2001.500000 -> 2001.5, but can introduce scientific notation (600.000 -> 6E+2)
.toPlainString // converts to normal string (6E+2 -> "600")
case _ =>
val ds2 = ds.select(col(measureColumn).as(valueColumnName))
val collected = ds2.agg(aggExpression).collect()(0)(0)
if (collected == null) 0 else collected
}
//check if total is required to be presented as larger type - big decimal
workaroundBigDecimalIssues(aggregatedValue)
}

private def workaroundBigDecimalIssues(value: Any): String = {
// If the aggregated value is a BigDecimal (Java or Scala), render it as a plain string; otherwise use toString
value match {
case v: java.math.BigDecimal =>
// Convert the value to string to workaround different serializers generate different JSONs for BigDecimal
v.stripTrailingZeros // removes trailing zeros (2001.500000 -> 2001.5, but can introduce scientific notation (600.000 -> 6E+2)
.toPlainString // converts to normal string (6E+2 -> "600")
case v: BigDecimal =>
// Convert the value to string to workaround different serializers generate different JSONs for BigDecimal
new java.math.BigDecimal(v.toString())
.stripTrailingZeros // removes trailing zeros (2001.500000 -> 2001.5, but can introduce scientific notation (600.000 -> 6E+2)
.toPlainString // converts to normal string (6E+2 -> "600")
case a => a.toString
}
}

}

/**
* This class is used for processing Spark Dataset to calculate aggregates / control measures
*/
class MeasurementProcessor(private var measurements: Seq[Measurement]) {
type MeasurementFunction = Dataset[Row] => String
type MeasurementProcessor = (Measurement, MeasurementFunction)

// Assigning measurement function to each measurement
var processors: Seq[MeasurementProcessor] =
measurements.map(m => (m, getMeasurementFunction(m)))

private val valueColumnName: String = "value"

/** The method calculates measurements for each control. */
private[atum] def measureDataset(ds: Dataset[Row]): Seq[Measurement] = {
@@ -72,80 +159,18 @@ class MeasurementProcessor(private var measurements: Seq[Measurement]) {

/** The method maps string representation of control type to measurement function. */
private def getMeasurementFunction(measurement: Measurement): MeasurementFunction = {

measurement.controlType match {
case Count.value => (ds: Dataset[Row]) => ds.count().toString
case DistinctCount.value => (ds: Dataset[Row]) => {
ds.select(col(measurement.controlCol)).distinct().count().toString
}
case AggregatedTotal.value =>
(ds: Dataset[Row]) => {
val aggCol = sum(col(valueColumnName))
aggregateColumn(ds, measurement.controlCol, aggCol)
}
case AbsAggregatedTotal.value =>
(ds: Dataset[Row]) => {
val aggCol = sum(abs(col(valueColumnName)))
aggregateColumn(ds, measurement.controlCol, aggCol)
}
case HashCrc32.value =>
(ds: Dataset[Row]) => {
val aggColName = ControlMeasureUtils.getTemporaryColumnName(ds)
val v = ds.withColumn(aggColName, crc32(col(measurement.controlCol).cast("String")))
.agg(sum(col(aggColName))).collect()(0)(0)
if (v == null) "" else v.toString
}
case _ =>
Try {
ControlType.withValueName(measurement.controlType)
} match {
case Failure(exception) =>
Atum.log.error(s"Unrecognized control measurement type '${measurement.controlType}'. Available control measurement types are: " +
s"${ControlType.values.mkString(",")}.")
Atum.log.error(exception.getLocalizedMessage)
(_: Dataset[Row]) => "N/A"
}
}

private def workaroundBigDecimalIssues(value: Any): String = {
// If aggregated value is java.math.BigDecimal, convert it to scala.math.BigDecimal
value match {
case v: java.math.BigDecimal =>
// Convert the value to string to workaround different serializers generate different JSONs for BigDecimal
v.stripTrailingZeros // removes trailing zeros (2001.500000 -> 2001.5, but can introduce scientific notation (600.000 -> 6E+2)
.toPlainString // converts to normal string (6E+2 -> "600")
case v: BigDecimal =>
// Convert the value to string to workaround different serializers generate different JSONs for BigDecimal
new java.math.BigDecimal(v.toString())
.stripTrailingZeros // removes trailing zeros (2001.500000 -> 2001.5, but can introduce scientific notation (600.000 -> 6E+2)
.toPlainString // converts to normal string (6E+2 -> "600")
case a => a.toString
case Success(controlType) =>
MeasurementProcessor.getMeasurementFunction(measurement.controlCol, controlType)
}
}

private def aggregateColumn(ds: Dataset[Row], measureColumn: String, aggExpression: Column): String = {
val dataType = ds.select(measureColumn).schema.fields(0).dataType
val aggregatedValue = dataType match {
case _: LongType =>
// This is protection against long overflow, e.g. Long.MaxValue = 9223372036854775807:
// scala> sc.parallelize(List(Long.MaxValue, 1)).toDF.agg(sum("value")).take(1)(0)(0)
// res11: Any = -9223372036854775808
// Converting to BigDecimal fixes the issue
//val ds2 = ds.select(col(measurement.controlCol).cast(DecimalType(38, 0)).as("value"))
//ds2.agg(sum(abs($"value"))).collect()(0)(0)
val ds2 = ds.select(col(measureColumn).cast(DecimalType(38, 0)).as(valueColumnName))
val collected = ds2.agg(aggExpression).collect()(0)(0)
if (collected == null) 0 else collected
case _: StringType =>
// Support for string type aggregation
val ds2 = ds.select(col(measureColumn).cast(DecimalType(38, 18)).as(valueColumnName))
val collected = ds2.agg(aggExpression).collect()(0)(0)
val value = if (collected==null) new java.math.BigDecimal(0) else collected.asInstanceOf[java.math.BigDecimal]
value
.stripTrailingZeros // removes trailing zeros (2001.500000 -> 2001.5, but can introduce scientific notation (600.000 -> 6E+2)
.toPlainString // converts to normal string (6E+2 -> "600")
case _ =>
val ds2 = ds.select(col(measureColumn).as(valueColumnName))
val collected = ds2.agg(aggExpression).collect()(0)(0)
if (collected == null) 0 else collected
}
//check if total is required to be presented as larger type - big decimal
workaroundBigDecimalIssues(aggregatedValue)
}

}
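
The comments in `aggregateColumn` and `workaroundBigDecimalIssues` refer to two numeric pitfalls: `Long` sums that overflow, and `BigDecimal` values that print in scientific notation after trailing zeros are stripped. The following is a minimal plain-Scala sketch of both, without Spark; the object and method names are hypothetical and exist only for this illustration.

```scala
object BigDecimalSketch {
  // Plain Long addition wraps around silently on overflow.
  def addLongs(a: Long, b: Long): Long = a + b

  // Widening to BigDecimal before adding keeps the exact result.
  def addExact(a: Long, b: Long): BigDecimal = BigDecimal(a) + BigDecimal(b)

  // Same call chain as in the diff: strip trailing zeros, then force plain (non-scientific) notation.
  def render(value: java.math.BigDecimal): String =
    value.stripTrailingZeros.toPlainString
}
```

For example, `BigDecimalSketch.addLongs(Long.MaxValue, 1L)` wraps to `Long.MinValue`, whereas `addExact` returns `9223372036854775808`. Likewise, `new java.math.BigDecimal("600.000").stripTrailingZeros.toString` gives `6E+2`, while `render` returns `"600"`, which is why the diff calls `toPlainString` after `stripTrailingZeros`.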