Skip to content

Commit

Permalink
[SPARK-44075][CONNECT] Make transformStatCorr lazy
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Make `transformStatCorr` lazy

### Why are the changes needed?
The current implementation eagerly computes the result and builds a local relation from it; this computation can be deferred.

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
existing CI

Closes apache#41621 from zhengruifeng/connect_lazy_corr.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
  • Loading branch information
zhengruifeng authored and HyukjinKwon committed Jun 16, 2023
1 parent 1b12094 commit 8212d6f
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -406,15 +406,15 @@ class SparkConnectPlanner(val session: SparkSession) extends Logging {

private def transformStatCorr(rel: proto.StatCorr): LogicalPlan = {
val df = Dataset.ofRows(session, transformRelation(rel.getInput))
val corr = if (rel.hasMethod) {
df.stat.corr(rel.getCol1, rel.getCol2, rel.getMethod)
if (rel.hasMethod) {
StatFunctions
.calculateCorrImpl(df, Seq(rel.getCol1, rel.getCol2), rel.getMethod)
.logicalPlan
} else {
df.stat.corr(rel.getCol1, rel.getCol2)
StatFunctions
.calculateCorrImpl(df, Seq(rel.getCol1, rel.getCol2))
.logicalPlan
}

LocalRelation.fromProduct(
output = AttributeReference("corr", DoubleType, false)() :: Nil,
data = Tuple1.apply(corr) :: Nil)
}

private def transformStatApproxQuantile(rel: proto.StatApproxQuantile): LogicalPlan = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,18 @@ object StatFunctions extends Logging {

/**
 * Calculate the Pearson Correlation Coefficient for the given columns.
 *
 * Delegates the construction of the correlation query to `calculateCorrImpl` and then
 * reads the single resulting value.
 *
 * @param df the DataFrame that contains the columns
 * @param cols names of the columns to correlate; `calculateCorrImpl` requires exactly two
 * @return the Double stored in the first column of the first result row
 */
def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = {
// `head` fetches the first row of the DataFrame returned by the helper; column 0 holds
// the coefficient.
calculateCorrImpl(df, cols).head.getDouble(0)
}

private[sql] def calculateCorrImpl(
df: DataFrame,
cols: Seq[String],
method: String = "pearson"): DataFrame = {
require(method == "pearson", "Currently only the calculation of the Pearson Correlation " +
"coefficient is supported.")
require(cols.length == 2,
"Currently correlation calculation is supported between two columns.")

val Seq(col1, col2) = cols.map { c =>
val dataType = df.resolve(c).dataType
require(dataType.isInstanceOf[NumericType],
Expand All @@ -123,7 +133,8 @@ object StatFunctions extends Logging {
df.select(
when(isnull(correlation), lit(Double.NaN))
.otherwise(correlation)
).head.getDouble(0)
.as("corr")
)
}

/**
Expand Down

0 comments on commit 8212d6f

Please sign in to comment.