Commit 7f32da6: Address comments on PR awslabs#532

eycho-am committed Feb 19, 2024
1 parent 5b818be

Showing 11 changed files with 32 additions and 47 deletions.

@@ -25,7 +25,7 @@ import com.amazon.deequ.repository._
 import org.apache.spark.sql.{DataFrame, SparkSession}
 
 /** A class to build a VerificationRun using a fluent API */
-class VerificationRunBuilder(val data: DataFrame) {
+class VerificationRunBuilder(val data: DataFrame) {
 
   protected var requiredAnalyzers: Seq[Analyzer[_, Metric[_]]] = Seq.empty

26 changes: 18 additions & 8 deletions src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala

@@ -17,7 +17,7 @@
 package com.amazon.deequ.analyzers
 
 import com.amazon.deequ.analyzers.Analyzers._
-import com.amazon.deequ.analyzers.FilteredRow.FilteredRow
+import com.amazon.deequ.analyzers.FilteredRowOutcome.FilteredRowOutcome
 import com.amazon.deequ.analyzers.NullBehavior.NullBehavior
 import com.amazon.deequ.analyzers.runners._
 import com.amazon.deequ.metrics.DoubleMetric
@@ -172,6 +172,12 @@ trait Analyzer[S <: State[_], +M <: Metric[_]] extends Serializable {
   private[deequ] def copyStateTo(source: StateLoader, target: StatePersister): Unit = {
     source.load[S](this).foreach { state => target.persist(this, state) }
   }
+
+  private[deequ] def getRowLevelFilterTreatment(analyzerOptions: Option[AnalyzerOptions]): FilteredRowOutcome = {
+    analyzerOptions
+      .map { options => options.filteredRow }
+      .getOrElse(FilteredRowOutcome.TRUE)
+  }
 }
 
 /** An analyzer that runs a set of aggregation functions over the data,
@@ -257,15 +263,19 @@ case class NumMatchesAndCount(numMatches: Long, count: Long, override val fullCo
 }
 
 case class AnalyzerOptions(nullBehavior: NullBehavior = NullBehavior.Ignore,
-                           filteredRow: FilteredRow = FilteredRow.TRUE)
+                           filteredRow: FilteredRowOutcome = FilteredRowOutcome.TRUE)
 object NullBehavior extends Enumeration {
   type NullBehavior = Value
   val Ignore, EmptyString, Fail = Value
 }
 
-object FilteredRow extends Enumeration {
-  type FilteredRow = Value
+object FilteredRowOutcome extends Enumeration {
+  type FilteredRowOutcome = Value
   val NULL, TRUE = Value
+
+  implicit class FilteredRowOutcomeOps(value: FilteredRowOutcome) {
+    def getExpression: Column = expr(value.toString)
+  }
 }
 
 /** Base class for analyzers that compute ratios of matching predicates */
@@ -500,12 +510,12 @@ private[deequ] object Analyzers {
   def conditionalSelectionFilteredFromColumns(
       selection: Column,
       conditionColumn: Option[Column],
-      filterTreatment: String)
+      filterTreatment: FilteredRowOutcome)
     : Column = {
     conditionColumn
-      .map { condition => {
-        when(not(condition), expr(filterTreatment)).when(condition, selection)
-      } }
+      .map { condition =>
+        when(not(condition), filterTreatment.getExpression).when(condition, selection)
+      }
      .getOrElse(selection)
   }
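
Note on the hunks above: the new implicit FilteredRowOutcomeOps turns an enum value into a Spark Column by parsing its name, so FilteredRowOutcome.NULL.getExpression yields a null literal and FilteredRowOutcome.TRUE.getExpression a boolean true. The following is a minimal, self-contained sketch, not part of the commit (the sample rows and column names are invented), of how the rewritten conditionalSelectionFilteredFromColumns treats filtered rows:

    import org.apache.spark.sql.{Column, SparkSession}
    import org.apache.spark.sql.functions.{col, expr, not, when}

    object FilteredRowOutcomeSketch {
      // Mirrors the enumeration added in this commit.
      object FilteredRowOutcome extends Enumeration {
        type FilteredRowOutcome = Value
        val NULL, TRUE = Value

        implicit class FilteredRowOutcomeOps(value: FilteredRowOutcome) {
          // expr("NULL") parses to a null literal, expr("TRUE") to boolean true.
          def getExpression: Column = expr(value.toString)
        }
      }

      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder.master("local[1]").getOrCreate()
        import spark.implicits._

        // Invented sample: only rows with att1 = "a" are in scope.
        val df = Seq(("a", "x"), ("b", null)).toDF("att1", "att2")
        val condition = expr("att1 = \"a\"")

        // Same shape as the rewritten conditionalSelectionFilteredFromColumns.
        val rowLevel = when(not(condition), FilteredRowOutcome.NULL.getExpression)
          .when(condition, col("att2").isNotNull)

        // Expect true for the in-scope row and null for the filtered-out row.
        df.select(rowLevel.as("att2_rowLevel")).show()
      }
    }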

10 changes: 1 addition & 9 deletions src/main/scala/com/amazon/deequ/analyzers/Completeness.scala

@@ -20,7 +20,6 @@ import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isNotNested}
 import org.apache.spark.sql.functions.sum
 import org.apache.spark.sql.types.{IntegerType, StructType}
 import Analyzers._
-import com.amazon.deequ.analyzers.FilteredRow.FilteredRow
 import com.google.common.annotations.VisibleForTesting
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.functions.expr
@@ -54,15 +53,8 @@ case class Completeness(column: String, where: Option[String] = None,
   @VisibleForTesting // required by some tests that compare analyzer results to an expected state
   private[deequ] def criterion: Column = conditionalSelection(column, where).isNotNull
 
   @VisibleForTesting
   private[deequ] def rowLevelResults: Column = {
     val whereCondition = where.map { expression => expr(expression)}
-    conditionalSelectionFilteredFromColumns(col(column).isNotNull, whereCondition, getRowLevelFilterTreatment.toString)
-  }
-
-  private def getRowLevelFilterTreatment: FilteredRow = {
-    analyzerOptions
-      .map { options => options.filteredRow }
-      .getOrElse(FilteredRow.TRUE)
+    conditionalSelectionFilteredFromColumns(col(column).isNotNull, whereCondition, getRowLevelFilterTreatment(analyzerOptions))
   }
 }
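
A hedged usage sketch of the Completeness change (Completeness, AnalyzerOptions, computeStateFrom, and computeMetricFrom all appear in this commit's tests; the DataFrame and its att1/att2 columns are assumed): with FilteredRowOutcome.NULL, rows excluded by the where clause surface as null rather than true in the row-level results.

    import com.amazon.deequ.analyzers.{AnalyzerOptions, Completeness, FilteredRowOutcome}
    import org.apache.spark.sql.DataFrame

    // `data` is an assumed DataFrame with columns att1 and att2.
    def completenessRowLevel(data: DataFrame) = {
      val completeness = Completeness("att2", Some("att1 = \"a\""),
        Some(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL)))

      val state = completeness.computeStateFrom(data)
      val metric = completeness.computeMetricFrom(state)
      // metric.fullColumn should carry the row-level column: null for rows
      // failing the where filter, true/false for in-scope rows.
      metric
    }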

@@ -93,9 +93,8 @@ object FrequencyBasedAnalyzer {
     val fullColumn: Column = {
       val window = Window.partitionBy(columnsToGroupBy: _*)
       where.map {
-        condition => {
+        condition =>
           count(when(expr(condition), UNIQUENESS_ID)).over(window)
-        }
       }.getOrElse(count(UNIQUENESS_ID).over(window))
     }
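
For orientation: fullColumn above counts occurrences of each value group over a window, and with a where clause only matching rows contribute to the count. A sketch with invented stand-ins ("value" for the grouped columns, "id" for deequ's UNIQUENESS_ID, "unique < 4" for the filter):

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.{col, count, expr, when}

    val window = Window.partitionBy(col("value"))
    // Rows failing the filter are excluded from the per-group count.
    val fullColumn = count(when(expr("unique < 4"), col("id"))).over(window)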

10 changes: 2 additions & 8 deletions src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala

@@ -17,7 +17,7 @@
 package com.amazon.deequ.analyzers
 
 import com.amazon.deequ.analyzers.Analyzers.COUNT_COL
-import com.amazon.deequ.analyzers.FilteredRow.FilteredRow
+import com.amazon.deequ.analyzers.FilteredRowOutcome.FilteredRowOutcome
 import com.amazon.deequ.metrics.DoubleMetric
 import org.apache.spark.sql.functions.expr
 import org.apache.spark.sql.functions.not
@@ -43,7 +43,7 @@ case class UniqueValueRatio(columns: Seq[String], where: Option[String] = None,
       rowLevelColumn => {
         conditionColumn.map {
           condition => {
-            when(not(condition), expr(getRowLevelFilterTreatment.toString))
+            when(not(condition), getRowLevelFilterTreatment(analyzerOptions).getExpression)
               .when(rowLevelColumn.equalTo(1), true).otherwise(false)
           }
         }.getOrElse(when(rowLevelColumn.equalTo(1), true).otherwise(false))
@@ -53,12 +53,6 @@ case class UniqueValueRatio(columns: Seq[String], where: Option[String] = None,
   }
 
   override def filterCondition: Option[String] = where
-
-  private def getRowLevelFilterTreatment: FilteredRow = {
-    analyzerOptions
-      .map { options => options.filteredRow }
-      .getOrElse(FilteredRow.TRUE)
-  }
 }
 
 object UniqueValueRatio {

10 changes: 2 additions & 8 deletions src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala

@@ -17,7 +17,7 @@
 package com.amazon.deequ.analyzers
 
 import com.amazon.deequ.analyzers.Analyzers.COUNT_COL
-import com.amazon.deequ.analyzers.FilteredRow.FilteredRow
+import com.amazon.deequ.analyzers.FilteredRowOutcome.FilteredRowOutcome
 import com.amazon.deequ.metrics.DoubleMetric
 import com.google.common.annotations.VisibleForTesting
 import org.apache.spark.sql.Column
@@ -47,7 +47,7 @@ case class Uniqueness(columns: Seq[String], where: Option[String] = None,
       rowLevelColumn => {
         conditionColumn.map {
           condition => {
-            when(not(condition), expr(getRowLevelFilterTreatment.toString))
+            when(not(condition), getRowLevelFilterTreatment(analyzerOptions).getExpression)
               .when(rowLevelColumn.equalTo(1), true).otherwise(false)
           }
         }.getOrElse(when(rowLevelColumn.equalTo(1), true).otherwise(false))
@@ -57,12 +57,6 @@ case class Uniqueness(columns: Seq[String], where: Option[String] = None,
   }
 
   override def filterCondition: Option[String] = where
-
-  private def getRowLevelFilterTreatment: FilteredRow = {
-    analyzerOptions
-      .map { options => options.filteredRow }
-      .getOrElse(FilteredRow.TRUE)
-  }
 }
 
 object Uniqueness {
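
The Uniqueness rewrite mirrors UniqueValueRatio: both drop their duplicated private helper in favor of the shared getRowLevelFilterTreatment on the Analyzer trait. As a sketch of the resulting row-level semantics (the frequency column name and the where clause are assumed stand-ins; the real frequency column is deequ's Analyzers.COUNT_COL):

    import com.amazon.deequ.analyzers.FilteredRowOutcome
    import org.apache.spark.sql.Column
    import org.apache.spark.sql.functions.{col, expr, not, when}

    val condition: Column = expr("unique < 4") // example where clause
    val rowLevel: Column =
      when(not(condition), FilteredRowOutcome.NULL.getExpression)
        .when(col("frequency").equalTo(1), true).otherwise(false)
    // Filtered rows -> null; in-scope rows -> true iff the value occurs once.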

3 changes: 1 addition & 2 deletions src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala

@@ -349,7 +349,7 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec
     "generate a result that contains row-level results with null for filtered rows" in withSparkSession { session =>
       val data = getDfCompleteAndInCompleteColumns(session)
 
-      val analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRow.NULL))
+      val analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL))
 
       val completeness = new Check(CheckLevel.Error, "rule1")
         .hasCompleteness("att2", _ > 0.7, None, analyzerOptions)
@@ -386,7 +386,6 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec
 
       val rowLevel3 = resultData.select(expectedColumn3).collect().map(r => r.getAs[Any](0))
       assert(Seq(true, true, null, null, null, null).sameElements(rowLevel3))
-
     }
 
     "generate a result that contains row-level results for null column values" in withSparkSession { session =>

@@ -46,7 +46,7 @@ class CompletenessTest extends AnyWordSpec with Matchers with SparkContextSpec w
 
       // Explicitly setting RowLevelFilterTreatment for test purposes, this should be set at the VerificationRunBuilder
       val completenessAtt2 = Completeness("att2", Option("att1 = \"a\""),
-        Option(AnalyzerOptions(filteredRow = FilteredRow.NULL)))
+        Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL)))
       val state = completenessAtt2.computeStateFrom(data)
       val metric: DoubleMetric with FullColumn = completenessAtt2.computeMetricFrom(state)

@@ -123,7 +123,7 @@ class UniquenessTest extends AnyWordSpec with Matchers with SparkContextSpec wit
       val data = getDfWithUniqueColumns(session)
 
       val addressLength = Uniqueness(Seq("onlyUniqueWithOtherNonUnique"), Option("unique < 4"),
-        Option(AnalyzerOptions(filteredRow = FilteredRow.NULL)))
+        Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL)))
       val state: Option[FrequenciesAndNumRows] = addressLength.computeStateFrom(data, Option("unique < 4"))
       val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state)
@@ -139,7 +139,7 @@ class UniquenessTest extends AnyWordSpec with Matchers with SparkContextSpec wit
       val data = getDfWithUniqueColumns(session)
 
       val addressLength = Uniqueness(Seq("halfUniqueCombinedWithNonUnique", "nonUnique"), Option("unique > 2"),
-        Option(AnalyzerOptions(filteredRow = FilteredRow.NULL)))
+        Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL)))
       val state: Option[FrequenciesAndNumRows] = addressLength.computeStateFrom(data, Option("unique > 2"))
       val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state)

@@ -204,10 +204,9 @@ class AnalysisRunnerTests extends AnyWordSpec
       // Used to be tested with the above line, but adding filters changed the order of the results.
       assert(separateResults.asInstanceOf[Set[DoubleMetric]].size ==
         runnerResults.asInstanceOf[Set[DoubleMetric]].size)
-      separateResults.asInstanceOf[Set[DoubleMetric]].foreach( result => {
-        assert(runnerResults.toString.contains(result.toString))
-      }
-      )
+      separateResults.asInstanceOf[Set[DoubleMetric]].foreach( result =>
+        assert(runnerResults.toString.contains(result.toString))
+      )
     }
 
     "reuse existing results" in

@@ -146,7 +146,5 @@ class AnalyzerContextTest extends AnyWordSpec
 
   private[this] def assertSameJson(jsonA: String, jsonB: String): Unit = {
     assert(SimpleResultSerde.deserialize(jsonA).toSet.sameElements(SimpleResultSerde.deserialize(jsonB).toSet))
-    // assert(SimpleResultSerde.deserialize(jsonA) ==
-    //   SimpleResultSerde.deserialize(jsonB))
   }
 }
