
Feature: Add Row Level Result Treatment Options for Uniqueness and Completeness #532

Merged
merged 8 commits on Feb 15, 2024
3 changes: 1 addition & 2 deletions src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala
@@ -25,7 +25,7 @@ import com.amazon.deequ.repository._
import org.apache.spark.sql.{DataFrame, SparkSession}

/** A class to build a VerificationRun using a fluent API */
class VerificationRunBuilder(val data: DataFrame) {
class VerificationRunBuilder(val data: DataFrame) {
Review comment (Contributor): nit: excess


protected var requiredAnalyzers: Seq[Analyzer[_, Metric[_]]] = Seq.empty

@@ -159,7 +159,6 @@ class VerificationRunBuilder(val data: DataFrame) {
new VerificationRunBuilderWithSparkSession(this, Option(sparkSession))
}


def run(): VerificationResult = {
VerificationSuite().doVerificationRun(
data,
31 changes: 25 additions & 6 deletions src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala
@@ -17,6 +17,7 @@
package com.amazon.deequ.analyzers

import com.amazon.deequ.analyzers.Analyzers._
import com.amazon.deequ.analyzers.FilteredRow.FilteredRow
import com.amazon.deequ.analyzers.NullBehavior.NullBehavior
import com.amazon.deequ.analyzers.runners._
import com.amazon.deequ.metrics.DoubleMetric
@@ -69,7 +70,7 @@ trait Analyzer[S <: State[_], +M <: Metric[_]] extends Serializable {
* @param data data frame
* @param filterCondition optional SQL predicate; rows it excludes can receive special treatment in the row-level results
* @return
*/
def computeStateFrom(data: DataFrame): Option[S]
def computeStateFrom(data: DataFrame, filterCondition: Option[String] = None): Option[S]

/**
* Compute the metric from the state (sufficient statistics)
@@ -97,13 +98,14 @@ trait Analyzer[S <: State[_], +M <: Metric[_]] extends Serializable {
def calculate(
data: DataFrame,
aggregateWith: Option[StateLoader] = None,
saveStatesWith: Option[StatePersister] = None)
saveStatesWith: Option[StatePersister] = None,
filterCondition: Option[String] = None)
: M = {

try {
preconditions.foreach { condition => condition(data.schema) }

val state = computeStateFrom(data)
val state = computeStateFrom(data, filterCondition)

calculateMetric(state, aggregateWith, saveStatesWith)
} catch {
@@ -170,7 +172,6 @@ trait Analyzer[S <: State[_], +M <: Metric[_]] extends Serializable {
private[deequ] def copyStateTo(source: StateLoader, target: StatePersister): Unit = {
source.load[S](this).foreach { state => target.persist(this, state) }
}

}

/** An analyzer that runs a set of aggregation functions over the data,
@@ -184,7 +185,7 @@ trait ScanShareableAnalyzer[S <: State[_], +M <: Metric[_]] extends Analyzer[S,
private[deequ] def fromAggregationResult(result: Row, offset: Int): Option[S]

/** Runs aggregation functions directly, without scan sharing */
override def computeStateFrom(data: DataFrame): Option[S] = {
override def computeStateFrom(data: DataFrame, where: Option[String] = None): Option[S] = {
val aggregations = aggregationFunctions()
val result = data.agg(aggregations.head, aggregations.tail: _*).collect().head
fromAggregationResult(result, 0)
@@ -255,12 +256,18 @@ case class NumMatchesAndCount(numMatches: Long, count: Long, override val fullCo
}
}

case class AnalyzerOptions(nullBehavior: NullBehavior = NullBehavior.Ignore)
case class AnalyzerOptions(nullBehavior: NullBehavior = NullBehavior.Ignore,
filteredRow: FilteredRow = FilteredRow.TRUE)
Review comment (Contributor): How about filteredRowOutcome or filteredRowEvaluationStatus? AnalyzerOptions is a public-facing API, and filteredRow could be confusing for customers.

object NullBehavior extends Enumeration {
type NullBehavior = Value
val Ignore, EmptyString, Fail = Value
}

object FilteredRow extends Enumeration {
type FilteredRow = Value
val NULL, TRUE = Value
}
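
For context, a quick usage sketch of the new option, using the Completeness signature added below in this PR (the column name and predicate are illustrative):

import com.amazon.deequ.analyzers.{AnalyzerOptions, Completeness, FilteredRow}

// Rows excluded by the where-filter surface as NULL in the row-level
// results instead of the default TRUE.
val completeness = Completeness(
  column = "attribute",
  where = Some("item < 3"),
  analyzerOptions = Some(AnalyzerOptions(filteredRow = FilteredRow.NULL))
)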

/** Base class for analyzers that compute ratios of matching predicates */
abstract class PredicateMatchingAnalyzer(
name: String,
@@ -490,6 +497,18 @@ private[deequ] object Analyzers {
conditionalSelectionFromColumns(selection, conditionColumn)
}

def conditionalSelectionFilteredFromColumns(
selection: Column,
conditionColumn: Option[Column],
filterTreatment: String)
: Column = {
conditionColumn
.map { condition => {
when(not(condition), expr(filterTreatment)).when(condition, selection)
Review comment (Contributor): nit: we can remove the { after => and its enclosing }.

Review comment (Contributor): Can we delegate the expr(filterTreatment) call to the method's parameter? We could change the type of filterTreatment to FilteredRow and let the expression for each FilteredRow enumeration value live inside FilteredRow itself. Right now the .toString breaks the connection between FilteredRow and this method; keeping that connection would aid refactoring and the general readability of the code.

} }
.getOrElse(selection)
}
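
A minimal sketch of the refactor suggested above, assuming the expression mapping may live on the enumeration itself (the helper name expression is an assumption, not part of the merged change):

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{lit, not, when}

object FilteredRow extends Enumeration {
  type FilteredRow = Value
  val NULL, TRUE = Value

  // Hypothetical helper: each treatment knows the literal column it stands for,
  // so callers no longer round-trip through .toString and expr(...).
  def expression(treatment: FilteredRow): Column = treatment match {
    case NULL => lit(null)
    case TRUE => lit(true)
  }
}

def conditionalSelectionFilteredFromColumns(
    selection: Column,
    conditionColumn: Option[Column],
    filterTreatment: FilteredRow.FilteredRow): Column = {
  conditionColumn
    .map { condition =>
      when(not(condition), FilteredRow.expression(filterTreatment))
        .when(condition, selection)
    }
    .getOrElse(selection)
}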

private[this] def conditionalSelectionFromColumns(
selection: Column,
conditionColumn: Option[Column])
22 changes: 18 additions & 4 deletions src/main/scala/com/amazon/deequ/analyzers/Completeness.scala
@@ -20,19 +20,21 @@ import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isNotNested}
import org.apache.spark.sql.functions.sum
import org.apache.spark.sql.types.{IntegerType, StructType}
import Analyzers._
import com.amazon.deequ.analyzers.FilteredRow.FilteredRow
import com.google.common.annotations.VisibleForTesting
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.{Column, Row}

/** Completeness is the fraction of non-null values in a column of a DataFrame. */
case class Completeness(column: String, where: Option[String] = None) extends
case class Completeness(column: String, where: Option[String] = None,
analyzerOptions: Option[AnalyzerOptions] = None) extends
StandardScanShareableAnalyzer[NumMatchesAndCount]("Completeness", column) with
FilterableAnalyzer {

override def fromAggregationResult(result: Row, offset: Int): Option[NumMatchesAndCount] = {

ifNoNullsIn(result, offset, howMany = 2) { _ =>
NumMatchesAndCount(result.getLong(offset), result.getLong(offset + 1), Some(criterion))
NumMatchesAndCount(result.getLong(offset), result.getLong(offset + 1), Some(rowLevelResults))
}
}

@@ -51,4 +53,16 @@ case class Completeness(column: String, where: Option[String] = None) extends

@VisibleForTesting // required by some tests that compare analyzer results to an expected state
private[deequ] def criterion: Column = conditionalSelection(column, where).isNotNull

@VisibleForTesting
Review comment (Contributor): Do we need this annotation? The method is already accessible to classes in com.amazon.deequ, which the tests are under.

private[deequ] def rowLevelResults: Column = {
val whereCondition = where.map { expression => expr(expression)}
conditionalSelectionFilteredFromColumns(col(column).isNotNull, whereCondition, getRowLevelFilterTreatment.toString)
}

private def getRowLevelFilterTreatment: FilteredRow = {
analyzerOptions
.map { options => options.filteredRow }
.getOrElse(FilteredRow.TRUE)
}
}
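
To make the .toString round-trip concrete, here is a hedged expansion of what rowLevelResults evaluates to for where = Some("item < 3") with the default FilteredRow.TRUE (the column and predicate are illustrative):

import org.apache.spark.sql.functions.{col, expr, not, when}

// Filtered rows (item >= 3) are forced to true; rows inside the filter
// report the actual non-null check on the column.
val rowLevel = when(not(expr("item < 3")), expr("TRUE"))
  .when(expr("item < 3"), col("attribute").isNotNull)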
2 changes: 1 addition & 1 deletion src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala
@@ -33,7 +33,7 @@ case class CustomSql(expression: String) extends Analyzer[CustomSqlState, Double
* @param data data frame
* @return
*/
override def computeStateFrom(data: DataFrame): Option[CustomSqlState] = {
override def computeStateFrom(data: DataFrame, filterCondition: Option[String] = None): Option[CustomSqlState] = {

Try {
data.sqlContext.sql(expression)
@@ -69,7 +69,7 @@ case class DatasetMatchAnalyzer(dfToCompare: DataFrame,
matchColumnMappings: Option[Map[String, String]] = None)
extends Analyzer[DatasetMatchState, DoubleMetric] {

override def computeStateFrom(data: DataFrame): Option[DatasetMatchState] = {
override def computeStateFrom(data: DataFrame, filterCondition: Option[String] = None): Option[DatasetMatchState] = {

val result = if (matchColumnMappings.isDefined) {
DataSynchronization.columnMatch(data, dfToCompare, columnMappings, matchColumnMappings.get, assertion)
16 changes: 13 additions & 3 deletions src/main/scala/com/amazon/deequ/analyzers/GroupingAnalyzers.scala
@@ -32,15 +32,17 @@ import org.apache.spark.sql.functions.count
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.functions.when

/** Base class for all analyzers that operate on the frequencies of groups in the data */
abstract class FrequencyBasedAnalyzer(columnsToGroupOn: Seq[String])
extends GroupingAnalyzer[FrequenciesAndNumRows, DoubleMetric] {

override def groupingColumns(): Seq[String] = { columnsToGroupOn }

override def computeStateFrom(data: DataFrame): Option[FrequenciesAndNumRows] = {
Some(FrequencyBasedAnalyzer.computeFrequencies(data, groupingColumns()))
override def computeStateFrom(data: DataFrame,
filterCondition: Option[String] = None): Option[FrequenciesAndNumRows] = {
Some(FrequencyBasedAnalyzer.computeFrequencies(data, groupingColumns(), filterCondition))
}

/** We need at least one grouping column, and all specified columns must exist */
@@ -88,7 +90,15 @@ object FrequencyBasedAnalyzer {
.count()

// Set rows with value count 1 to true, and otherwise false
val fullColumn: Column = count(UNIQUENESS_ID).over(Window.partitionBy(columnsToGroupBy: _*))
val fullColumn: Column = {
val window = Window.partitionBy(columnsToGroupBy: _*)
where.map {
condition => {
count(when(expr(condition), UNIQUENESS_ID)).over(window)
}
Review comment (Contributor, on lines +96 to +98): nit: we can remove the brackets after =>.
}.getOrElse(count(UNIQUENESS_ID).over(window))
}

FrequenciesAndNumRows(frequencies, numRows, Option(fullColumn))
}
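
For intuition, a standalone sketch of what the filtered window count computes (illustrative data, not part of the change):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{count, expr, lit, when}

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val df = Seq(("a", 1), ("a", 5), ("b", 2)).toDF("name", "item")
val window = Window.partitionBy($"name")

// when(...) without otherwise(...) yields NULL for non-matching rows, and
// count(...) skips NULLs, so only rows passing the filter are counted per group.
df.withColumn("inFilterCount",
    count(when(expr("item < 3"), lit(1))).over(window))
  .show()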

3 changes: 2 additions & 1 deletion src/main/scala/com/amazon/deequ/analyzers/Histogram.scala
@@ -59,7 +59,8 @@ case class Histogram(
}
}

override def computeStateFrom(data: DataFrame): Option[FrequenciesAndNumRows] = {
override def computeStateFrom(data: DataFrame,
filterCondition: Option[String] = None): Option[FrequenciesAndNumRows] = {

// TODO: figure out a way to pass this in if it's known beforehand
val totalCount = if (computeFrequenciesAsRatio) {
26 changes: 23 additions & 3 deletions src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala
@@ -17,13 +17,17 @@
package com.amazon.deequ.analyzers

import com.amazon.deequ.analyzers.Analyzers.COUNT_COL
import com.amazon.deequ.analyzers.FilteredRow.FilteredRow
import com.amazon.deequ.metrics.DoubleMetric
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.functions.not
import org.apache.spark.sql.functions.when
import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.functions.{col, count, lit, sum}
import org.apache.spark.sql.types.DoubleType

case class UniqueValueRatio(columns: Seq[String], where: Option[String] = None)
case class UniqueValueRatio(columns: Seq[String], where: Option[String] = None,
analyzerOptions: Option[AnalyzerOptions] = None)
extends ScanShareableFrequencyBasedAnalyzer("UniqueValueRatio", columns)
with FilterableAnalyzer {

@@ -34,11 +38,27 @@ case class UniqueValueRatio(columns: Seq[String], where: Option[String] = None)
override def fromAggregationResult(result: Row, offset: Int, fullColumn: Option[Column] = None): DoubleMetric = {
val numUniqueValues = result.getDouble(offset)
val numDistinctValues = result.getLong(offset + 1).toDouble
val fullColumnUniqueness = when((fullColumn.getOrElse(null)).equalTo(1), true).otherwise(false)
toSuccessMetric(numUniqueValues / numDistinctValues, Option(fullColumnUniqueness))
val conditionColumn = where.map { expression => expr(expression) }
val fullColumnUniqueness = fullColumn.map {
rowLevelColumn => {
conditionColumn.map {
condition => {
when(not(condition), expr(getRowLevelFilterTreatment.toString))
Review comment (Contributor): Same comment for expr(getRowLevelFilterTreatment.toString) as above.
.when(rowLevelColumn.equalTo(1), true).otherwise(false)
}
}.getOrElse(when(rowLevelColumn.equalTo(1), true).otherwise(false))
}
}
toSuccessMetric(numUniqueValues / numDistinctValues, fullColumnUniqueness)
}

override def filterCondition: Option[String] = where

private def getRowLevelFilterTreatment: FilteredRow = {
analyzerOptions
.map { options => options.filteredRow }
.getOrElse(FilteredRow.TRUE)
}
Review comment (Contributor, on lines +57 to +61): nit: this is repeated in a few places, so it could go into the base class.

}
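
A sketch of the deduplication suggested above, assuming a small mix-in trait is acceptable (the trait name FilteredRowSupport is an assumption, not part of the merged change):

import com.amazon.deequ.analyzers.FilteredRow.FilteredRow

// Hypothetical mix-in so Completeness, Uniqueness, and UniqueValueRatio
// can share one definition instead of three copies.
trait FilteredRowSupport {
  def analyzerOptions: Option[AnalyzerOptions]

  protected def getRowLevelFilterTreatment: FilteredRow = {
    analyzerOptions
      .map { options => options.filteredRow }
      .getOrElse(FilteredRow.TRUE)
  }
}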

object UniqueValueRatio {
29 changes: 25 additions & 4 deletions src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala
@@ -17,31 +17,52 @@
package com.amazon.deequ.analyzers

import com.amazon.deequ.analyzers.Analyzers.COUNT_COL
import com.amazon.deequ.analyzers.FilteredRow.FilteredRow
import com.amazon.deequ.metrics.DoubleMetric
import com.google.common.annotations.VisibleForTesting
import org.apache.spark.sql.Column
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.when
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.not
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.functions.sum
import org.apache.spark.sql.types.DoubleType

/** Uniqueness is the fraction of unique values of a column(s), i.e.,
* values that occur exactly once. */
case class Uniqueness(columns: Seq[String], where: Option[String] = None)
case class Uniqueness(columns: Seq[String], where: Option[String] = None,
analyzerOptions: Option[AnalyzerOptions] = None)
extends ScanShareableFrequencyBasedAnalyzer("Uniqueness", columns)
with FilterableAnalyzer {

override def aggregationFunctions(numRows: Long): Seq[Column] = {
(sum(col(COUNT_COL).equalTo(lit(1)).cast(DoubleType)) / numRows) :: Nil
(sum(col(COUNT_COL).equalTo(lit(1)).cast(DoubleType)) / numRows) :: Nil
}

override def fromAggregationResult(result: Row, offset: Int, fullColumn: Option[Column]): DoubleMetric = {
val fullColumnUniqueness = when((fullColumn.getOrElse(null)).equalTo(1), true).otherwise(false)
super.fromAggregationResult(result, offset, Option(fullColumnUniqueness))
val conditionColumn = where.map { expression => expr(expression) }
val fullColumnUniqueness = fullColumn.map {
rowLevelColumn => {
conditionColumn.map {
condition => {
when(not(condition), expr(getRowLevelFilterTreatment.toString))
.when(rowLevelColumn.equalTo(1), true).otherwise(false)
}
}.getOrElse(when(rowLevelColumn.equalTo(1), true).otherwise(false))
}
}
super.fromAggregationResult(result, offset, fullColumnUniqueness)
}

override def filterCondition: Option[String] = where

private def getRowLevelFilterTreatment: FilteredRow = {
analyzerOptions
.map { options => options.filteredRow }
.getOrElse(FilteredRow.TRUE)
}
}
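
And a matching usage sketch for Uniqueness with the new option (columns and predicate are illustrative):

import com.amazon.deequ.analyzers.{AnalyzerOptions, FilteredRow, Uniqueness}

// With FilteredRow.NULL, rows outside the where clause get NULL in the
// row-level uniqueness column rather than the default TRUE.
val uniqueness = Uniqueness(
  columns = Seq("id"),
  where = Some("status = 'active'"),
  analyzerOptions = Some(AnalyzerOptions(filteredRow = FilteredRow.NULL))
)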

object Uniqueness {