diff --git a/atum/src/main/scala/za/co/absa/atum/core/ControlType.scala b/atum/src/main/scala/za/co/absa/atum/core/ControlType.scala index f0ee9029..c8336738 100644 --- a/atum/src/main/scala/za/co/absa/atum/core/ControlType.scala +++ b/atum/src/main/scala/za/co/absa/atum/core/ControlType.scala @@ -25,7 +25,7 @@ object ControlType { val values = Seq(Count.value, DistinctCount.value, AggregatedTotal.value, AbsAggregatedTotal.value, HashCrc32.value) - def getNormalizedValue(input: String) = { + def getNormalizedValue(input: String): String = { values.find(value => isControlMeasureTypeEqual(input, value)).getOrElse(input) } diff --git a/atum/src/main/scala/za/co/absa/atum/utils/controlmeasure/ControlMeasureBuilder.scala b/atum/src/main/scala/za/co/absa/atum/utils/controlmeasure/ControlMeasureBuilder.scala index 9702ab92..8c1bbbe1 100644 --- a/atum/src/main/scala/za/co/absa/atum/utils/controlmeasure/ControlMeasureBuilder.scala +++ b/atum/src/main/scala/za/co/absa/atum/utils/controlmeasure/ControlMeasureBuilder.scala @@ -22,12 +22,14 @@ import org.slf4j.LoggerFactory import za.co.absa.atum.core.ControlType import za.co.absa.atum.model.CheckpointImplicits.CheckpointExt import za.co.absa.atum.model.{Checkpoint, ControlMeasure, ControlMeasureMetadata, Measurement} +import za.co.absa.atum.utils.controlmeasure.ControlMeasureBuilder.AggregateControlTypeStrategy import za.co.absa.atum.utils.controlmeasure.ControlMeasureUtils.getTimestampAsString import scala.util.Try trait ControlMeasureBuilder { - def withAggregateColumns(aggregateColumns: Seq[String]): ControlMeasureBuilder + def withAggregateColumns(aggregateColumns: Seq[String], + strategy: AggregateControlTypeStrategy = AggregateControlTypeStrategy.Default): ControlMeasureBuilder def withSourceApplication(sourceApplication: String): ControlMeasureBuilder def withInputPath(inputPath: String): ControlMeasureBuilder def withReportDate(reportDate: String): ControlMeasureBuilder @@ -43,6 +45,31 @@ trait ControlMeasureBuilder { 
object ControlMeasureBuilder { + + sealed trait AggregateControlTypeStrategy + object AggregateControlTypeStrategy { + + /** + * For numeric types controlType.absAggregatedTotal and for non-numeric controlType.HashCrc32 is used. + */ + case object Default extends AggregateControlTypeStrategy + + /** + * This controlType will be attempted to be used for all aggregateColumns. If unusable + * (e.g. AggregatedTotal or AbsAggregatedTotal for non-numeric columns), the Default strategy is used as a fallback. + * + * @param controlType single controlType to be attempted to be used for all aggregateColumns + */ + case class Common(controlType: ControlType) extends AggregateControlTypeStrategy + + /** + * Specify the concrete control types to be used, in order of the aggregateColumns. + * @param controlTypes sequence to use for aggregateColumns + */ + case class Specific(controlTypes: Seq[ControlType]) extends AggregateControlTypeStrategy + + } + /** * Get builder instance * @@ -74,6 +101,8 @@ object ControlMeasureBuilder { */ private case class ControlMeasureBuilderImpl(df: DataFrame, aggregateColumns: Seq[String] = Seq.empty, + aggregateControlTypeStrategy: AggregateControlTypeStrategy + = AggregateControlTypeStrategy.Default, sourceApplication: String = "", inputPathName: String = "", reportDate: String = ControlMeasureUtils.getTodayAsString, @@ -97,7 +126,38 @@ object ControlMeasureBuilder { def withSourceApplication(sourceApplication: String): ControlMeasureBuilderImpl = this.copy(sourceApplication = sourceApplication) def withInputPath(inputPath: String): ControlMeasureBuilderImpl = this.copy(inputPathName = inputPath) - def withAggregateColumns(aggregateColumns: Seq[String]): ControlMeasureBuilderImpl = this.copy(aggregateColumns = aggregateColumns) + /** + * Stores the aggregate columns together with the strategy for choosing their control types. + * The strategy is only checked here: a questionable combination is logged as a warning, not rejected. + * + * @param aggregateColumns columns to aggregate; must not be empty (enforced by `require`) + * @param strategy control type strategy to store for the aggregate columns + */ + def withAggregateColumns(aggregateColumns: Seq[String], strategy: AggregateControlTypeStrategy): ControlMeasureBuilderImpl = { + require(aggregateColumns.nonEmpty, "aggregateColumns must not be empty!") + + import 
AggregateControlTypeStrategy._ + strategy match { + case Default => // nothing to validate + // numeric check: warn only when some aggregate column really is non-numeric + case Common(controlType) if controlType == ControlType.AbsAggregatedTotal || controlType == ControlType.AggregatedTotal => { + val nonNumericFields = df.select(aggregateColumns.map(col):_*) + .schema.filter( field => !field.dataType.isInstanceOf[NumericType]) + if (nonNumericFields.nonEmpty) { + logger.warn(s"Aggregate columns ${nonNumericFields.map(field => s"${field.name} (${field.dataType})").mkString(",")} " + + s"are not numeric, but controlType strategy $strategy was set up. Default strategy will be used instead.") + } + } + case Common(_) => // the remaining control types need no numeric check; without this case they would end in a MatchError + case Specific(controlTypes) => + if (aggregateColumns.length != controlTypes.length) { + logger.warn(s"AggregateColumns(${aggregateColumns.length}, $aggregateColumns) size does not conform " + + s"the length of list of control types (${controlTypes.length}, $controlTypes). Default strategy might be used.") + } + } + + this.copy(aggregateColumns = aggregateColumns, aggregateControlTypeStrategy = strategy) + } def withReportDate(reportDate: String): ControlMeasureBuilderImpl = { if (Try(ControlMeasureUtils.dateFormat.parse(reportDate)).isFailure) { logger.error(s"Report date $reportDate does not validate against format ${ControlMeasureUtils.dateFormat}." + @@ -125,7 +185,9 @@ object ControlMeasureBuilder { calculateMeasurement() } - def calculateMeasurement(): ControlMeasure = { + def calculateMeasurement(): ControlMeasure = { // scalastyle:off + // todo apply selected aggregateControlTypeStrategy. 
Default = this: + // Calculate the measurements val timeStart = getTimestampAsString val rowCount = df.count() diff --git a/atum/src/test/scala/za/co/absa/atum/ControlMeasurementsSpec.scala b/atum/src/test/scala/za/co/absa/atum/ControlMeasurementsSpec.scala index fab0f8b1..f23dc0c5 100644 --- a/atum/src/test/scala/za/co/absa/atum/ControlMeasurementsSpec.scala +++ b/atum/src/test/scala/za/co/absa/atum/ControlMeasurementsSpec.scala @@ -37,7 +37,7 @@ class ControlMeasurementsSpec extends AnyFlatSpec with Matchers with SparkTestBa ) )) - val measurementsIntOferflow = List( + val measurementsIntOverflow = List( Measurement( controlName = "RecordCount", controlType = ControlType.Count.value, @@ -102,14 +102,14 @@ class ControlMeasurementsSpec extends AnyFlatSpec with Matchers with SparkTestBa .schema(schema) .json(inputDataJson.toDS) - val processor = new MeasurementProcessor(measurementsIntOferflow) + val processor = new MeasurementProcessor(measurementsIntOverflow) val newMeasurements = processor.measureDataset(df) println(newMeasurements) - println(measurementsIntOferflow) + println(measurementsIntOverflow) - assert(newMeasurements == measurementsIntOferflow) + assert(newMeasurements == measurementsIntOverflow) } val measurementsAggregation = List(