Skip to content

Commit

Permalink
#97 AggregateControlTypeStrategy suggested API for ControlMeasureBuil…
Browse files Browse the repository at this point in the history
…der usage
  • Loading branch information
dk1844 committed Jul 28, 2021
1 parent a1d99db commit f526eec
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 8 deletions.
2 changes: 1 addition & 1 deletion atum/src/main/scala/za/co/absa/atum/core/ControlType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ object ControlType {

val values = Seq(Count.value, DistinctCount.value, AggregatedTotal.value, AbsAggregatedTotal.value, HashCrc32.value)

def getNormalizedValue(input: String) = {
/**
 * Looks `input` up among the known control type `values`.
 *
 * @param input control type name to normalize
 * @return the first entry of `values` for which `isControlMeasureTypeEqual(input, entry)` holds,
 *         or `input` itself when no entry matches
 */
def getNormalizedValue(input: String): String = {
  val knownValue = values.find(isControlMeasureTypeEqual(input, _))
  knownValue.getOrElse(input)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@ import org.slf4j.LoggerFactory
import za.co.absa.atum.core.ControlType
import za.co.absa.atum.model.CheckpointImplicits.CheckpointExt
import za.co.absa.atum.model.{Checkpoint, ControlMeasure, ControlMeasureMetadata, Measurement}
import za.co.absa.atum.utils.controlmeasure.ControlMeasureBuilder.AggregateControlTypeStrategy
import za.co.absa.atum.utils.controlmeasure.ControlMeasureUtils.getTimestampAsString

import scala.util.Try

trait ControlMeasureBuilder {
def withAggregateColumns(aggregateColumns: Seq[String]): ControlMeasureBuilder
def withAggregateColumns(aggregateColumns: Seq[String],
strategy: AggregateControlTypeStrategy = AggregateControlTypeStrategy.Default): ControlMeasureBuilder
def withSourceApplication(sourceApplication: String): ControlMeasureBuilder
def withInputPath(inputPath: String): ControlMeasureBuilder
def withReportDate(reportDate: String): ControlMeasureBuilder
Expand All @@ -43,6 +45,31 @@ trait ControlMeasureBuilder {


object ControlMeasureBuilder {

/**
 * Selects which control type is used for the aggregate columns passed to
 * `ControlMeasureBuilder.withAggregateColumns`.
 */
sealed trait AggregateControlTypeStrategy
object AggregateControlTypeStrategy {

  /**
   * For numeric types controlType.absAggregatedTotal and for non-numeric controlType.HashCrc32 is used.
   */
  case object Default extends AggregateControlTypeStrategy

  /**
   * This controlType will be attempted to be used for all aggregateColumns. If it is unusable
   * (e.g. AggregatedTotal or AbsAggregatedTotal for a non-numeric column), the Default strategy
   * is used as a fallback.
   *
   * @param controlType single controlType to be attempted for all aggregateColumns
   */
  final case class Common(controlType: ControlType) extends AggregateControlTypeStrategy

  /**
   * Specify the concrete control types to be used, in order of the aggregateColumns.
   *
   * @param controlTypes sequence to use for aggregateColumns
   */
  final case class Specific(controlTypes: Seq[ControlType]) extends AggregateControlTypeStrategy

}

/**
* Get builder instance
*
Expand Down Expand Up @@ -74,6 +101,8 @@ object ControlMeasureBuilder {
*/
private case class ControlMeasureBuilderImpl(df: DataFrame,
aggregateColumns: Seq[String] = Seq.empty,
aggregateControlTypeStrategy: AggregateControlTypeStrategy
= AggregateControlTypeStrategy.Default,
sourceApplication: String = "",
inputPathName: String = "",
reportDate: String = ControlMeasureUtils.getTodayAsString,
Expand All @@ -97,7 +126,28 @@ object ControlMeasureBuilder {
/** Returns a copy of this builder whose `sourceApplication` is replaced by the given value. */
def withSourceApplication(sourceApplication: String): ControlMeasureBuilderImpl = {
  copy(sourceApplication = sourceApplication)
}
/** Returns a copy of this builder whose `inputPathName` is replaced by the given `inputPath`. */
def withInputPath(inputPath: String): ControlMeasureBuilderImpl = {
  copy(inputPathName = inputPath)
}

def withAggregateColumns(aggregateColumns: Seq[String]): ControlMeasureBuilderImpl = this.copy(aggregateColumns = aggregateColumns)
/**
 * Returns a copy of this builder with the aggregate columns and the control type strategy set.
 *
 * The strategy is only sanity-checked here: suspicious combinations are logged as warnings and the
 * strategy is stored unchanged.
 *
 * @param aggregateColumns columns to aggregate; must not be empty
 * @param strategy         control type strategy to store alongside the columns
 * @return builder copy carrying `aggregateColumns` and `strategy`
 * @throws IllegalArgumentException if `aggregateColumns` is empty
 */
def withAggregateColumns(aggregateColumns: Seq[String], strategy: AggregateControlTypeStrategy): ControlMeasureBuilderImpl = {
  require(aggregateColumns.nonEmpty, "aggregateColumns must not be empty!")

  import AggregateControlTypeStrategy._
  // Every variant of the sealed strategy is listed explicitly (no guard on the pattern itself), so the
  // match stays exhaustive: e.g. Common(ControlType.Count) must not end in a scala.MatchError.
  strategy match {
    case Default => // nothing to validate

    case Common(controlType) =>
      val requiresNumericColumns =
        controlType == ControlType.AbsAggregatedTotal || controlType == ControlType.AggregatedTotal

      // numeric check:
      if (requiresNumericColumns) {
        val nonNumericFields = df.select(aggregateColumns.map(col): _*)
          .schema.filter(field => !field.dataType.isInstanceOf[NumericType])

        // Warn only when there actually is a non-numeric column; otherwise the message would be misleading.
        if (nonNumericFields.nonEmpty) {
          logger.warn(s"Aggregate columns ${nonNumericFields.map(field => s"${field.name} (${field.dataType})").mkString(",")} " +
            s"are not numeric, but controlType strategy $strategy was set up. Default strategy will be used instead.")
        }
      }

    case Specific(controlTypes) =>
      if (aggregateColumns.length != controlTypes.length) {
        logger.warn(s"AggregateColumns(${aggregateColumns.length}, $aggregateColumns) size does not conform " +
          s"the length of list of control types (${controlTypes.length}, $controlTypes). Default strategy might be used.")
      }
  }

  this.copy(aggregateColumns = aggregateColumns, aggregateControlTypeStrategy = strategy)
}
def withReportDate(reportDate: String): ControlMeasureBuilderImpl = {
if (Try(ControlMeasureUtils.dateFormat.parse(reportDate)).isFailure) {
logger.error(s"Report date $reportDate does not validate against format ${ControlMeasureUtils.dateFormat}." +
Expand Down Expand Up @@ -125,7 +175,9 @@ object ControlMeasureBuilder {
calculateMeasurement()
}

def calculateMeasurement(): ControlMeasure = {
def calculateMeasurement(): ControlMeasure = { // scalastyle:off
// todo apply selected aggregateControlTypeStrategy. Default = this:

// Calculate the measurements
val timeStart = getTimestampAsString
val rowCount = df.count()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class ControlMeasurementsSpec extends AnyFlatSpec with Matchers with SparkTestBa
)
))

val measurementsIntOferflow = List(
val measurementsIntOverflow = List(
Measurement(
controlName = "RecordCount",
controlType = ControlType.Count.value,
Expand Down Expand Up @@ -102,14 +102,14 @@ class ControlMeasurementsSpec extends AnyFlatSpec with Matchers with SparkTestBa
.schema(schema)
.json(inputDataJson.toDS)

val processor = new MeasurementProcessor(measurementsIntOferflow)
val processor = new MeasurementProcessor(measurementsIntOverflow)
val newMeasurements = processor.measureDataset(df)

println(newMeasurements)

println(measurementsIntOferflow)
println(measurementsIntOverflow)

assert(newMeasurements == measurementsIntOferflow)
assert(newMeasurements == measurementsIntOverflow)
}

val measurementsAggregation = List(
Expand Down

0 comments on commit f526eec

Please sign in to comment.