# Patch summary (two files). These leading lines are commentary placed before
# the first file header; they are not part of any hunk.
#
# 1. SparkConnectPlanner.scala -- transformStatCorr
#    Before: called df.stat.corr(...) to obtain a value, then wrapped that value
#    in a one-row LocalRelation whose only attribute is a DoubleType column
#    named "corr".
#    After: returns the logicalPlan of the DataFrame produced by
#    StatFunctions.calculateCorrImpl. rel.getMethod is forwarded only when
#    rel.hasMethod is true; otherwise the helper's default method is used.
#
# 2. StatFunctions.scala -- pearsonCorrelation / calculateCorrImpl
#    pearsonCorrelation keeps its (df, cols) => Double signature but now
#    delegates to the new private[sql] calculateCorrImpl and extracts the result
#    with .head.getDouble(0).
#    calculateCorrImpl adds a `method` parameter defaulting to "pearson",
#    rejects any other value through require(...), keeps the existing
#    two-column check, and returns the select(...) DataFrame with its single
#    column aliased "corr" instead of collecting it to a Double.
#
# NOTE(review): the visible effect is that the planner no longer calls .head
# itself; presumably df.stat.corr went through pearsonCorrelation and therefore
# ran a job during plan translation -- confirm against DataFrameStatFunctions.
# NOTE(review): line breaks and hunk indentation were restored to match the
# @@ line counts and the usual 2-space Scala layout; verify against the target
# files before applying.
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 706dfe8096f42..494de5350b71c 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -406,15 +406,15 @@ class SparkConnectPlanner(val session: SparkSession) extends Logging {
 
   private def transformStatCorr(rel: proto.StatCorr): LogicalPlan = {
     val df = Dataset.ofRows(session, transformRelation(rel.getInput))
-    val corr = if (rel.hasMethod) {
-      df.stat.corr(rel.getCol1, rel.getCol2, rel.getMethod)
+    if (rel.hasMethod) {
+      StatFunctions
+        .calculateCorrImpl(df, Seq(rel.getCol1, rel.getCol2), rel.getMethod)
+        .logicalPlan
     } else {
-      df.stat.corr(rel.getCol1, rel.getCol2)
+      StatFunctions
+        .calculateCorrImpl(df, Seq(rel.getCol1, rel.getCol2))
+        .logicalPlan
     }
-
-    LocalRelation.fromProduct(
-      output = AttributeReference("corr", DoubleType, false)() :: Nil,
-      data = Tuple1.apply(corr) :: Nil)
   }
 
   private def transformStatApproxQuantile(rel: proto.StatApproxQuantile): LogicalPlan = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 92d5be1e34c99..dae2b70af78a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -109,8 +109,18 @@ object StatFunctions extends Logging {
 
   /** Calculate the Pearson Correlation Coefficient for the given columns */
   def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = {
+    calculateCorrImpl(df, cols).head.getDouble(0)
+  }
+
+  private[sql] def calculateCorrImpl(
+      df: DataFrame,
+      cols: Seq[String],
+      method: String = "pearson"): DataFrame = {
+    require(method == "pearson", "Currently only the calculation of the Pearson Correlation " +
+      "coefficient is supported.")
     require(cols.length == 2,
       "Currently correlation calculation is supported between two columns.")
+
     val Seq(col1, col2) = cols.map { c =>
       val dataType = df.resolve(c).dataType
       require(dataType.isInstanceOf[NumericType],
@@ -123,7 +133,8 @@ object StatFunctions extends Logging {
     df.select(
       when(isnull(correlation), lit(Double.NaN))
         .otherwise(correlation)
-    ).head.getDouble(0)
+        .as("corr")
+    )
   }
 
   /**