diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 975536c076ddb..543c60fd352e9 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -2425,6 +2425,11 @@ "message" : [ "A higher order function expects <expectedNumArgs> arguments, but got <actualNumArgs>." ] + }, + "UNEVALUABLE" : { + "message" : [ + "Evaluable expressions should be used for a lambda function in a higher order function. However, <funcName> was unevaluable." + ] + } } }, "sqlState" : "42K0D" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index bd8f8fe9f6528..43cce38b1a5d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -254,6 +254,13 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB hof.invalidFormat(checkRes) } + case e: HigherOrderFunction + if e.resolved && e.functions.exists(_.exists(_.isInstanceOf[Unevaluable])) => + val u = e.functions.flatMap(_.find(_.isInstanceOf[Unevaluable])).head + e.failAnalysis( + errorClass = "INVALID_LAMBDA_FUNCTION_CALL.UNEVALUABLE", + messageParameters = Map("funcName" -> toSQLExpr(u))) + // If an attribute can't be resolved as a map key of string type, either the key should be // surrounded with single quotes, or there is a typo in the attribute name. 
case GetMapValue(map, key: Attribute) if isMapWithStringKey(map) && !key.resolved => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala index 3101281251b1b..efbb0e4a70312 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.execution.python -import org.apache.spark.sql.{IntegratedUDFTestUtils, QueryTest} -import org.apache.spark.sql.functions.count +import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest} +import org.apache.spark.sql.functions.{array, count, transform} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.LongType @@ -112,4 +112,16 @@ class PythonUDFSuite extends QueryTest with SharedSparkSession { val pandasTestUDF = TestGroupedAggPandasUDF(name = udfName) assert(df.agg(pandasTestUDF(df("id"))).schema.fieldNames.exists(_.startsWith(udfName))) } + + test("SPARK-48706: Negative test case for Python UDF in higher order functions") { + assume(shouldTestPythonUDFs) + checkError( + exception = intercept[AnalysisException] { + spark.range(1).select(transform(array("id"), x => pythonTestUDF(x))).collect() + }, + errorClass = "INVALID_LAMBDA_FUNCTION_CALL.UNEVALUABLE", + parameters = Map("funcName" -> "\"pyUDF(namedlambdavariable())\""), + context = ExpectedContext( + "transform", s".*${this.getClass.getSimpleName}.*")) + } }