Skip to content

Commit

Permalink
Python UDF in higher order functions should not throw internal error
Browse files Browse the repository at this point in the history
  • Loading branch information
HyukjinKwon committed Jun 25, 2024
1 parent 2ac2710 commit 28f4da4
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 2 deletions.
5 changes: 5 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -2425,6 +2425,11 @@
"message" : [
"A higher order function expects <expectedNumArgs> arguments, but got <actualNumArgs>."
]
},
"UNEVALUABLE" : {
"message" : [
"Evaluable expressions should be used for a lambda function in a higher order function. However, <funcName> was unevaluable."
]
}
},
"sqlState" : "42K0D"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,13 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
hof.invalidFormat(checkRes)
}

case e: HigherOrderFunction
if e.resolved && e.functions.exists(_.exists(_.isInstanceOf[Unevaluable])) =>
val u = e.functions.flatMap(_.find(_.isInstanceOf[Unevaluable])).head
e.failAnalysis(
errorClass = "INVALID_LAMBDA_FUNCTION_CALL.UNEVALUABLE",
messageParameters = Map("funcName" -> toSQLExpr(u)))

// If an attribute can't be resolved as a map key of string type, either the key should be
// surrounded with single quotes, or there is a typo in the attribute name.
case GetMapValue(map, key: Attribute) if isMapWithStringKey(map) && !key.resolved =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

package org.apache.spark.sql.execution.python

import org.apache.spark.sql.{IntegratedUDFTestUtils, QueryTest}
import org.apache.spark.sql.functions.count
import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest}
import org.apache.spark.sql.functions.{array, count, transform}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.LongType

Expand Down Expand Up @@ -112,4 +112,16 @@ class PythonUDFSuite extends QueryTest with SharedSparkSession {
val pandasTestUDF = TestGroupedAggPandasUDF(name = udfName)
assert(df.agg(pandasTestUDF(df("id"))).schema.fieldNames.exists(_.startsWith(udfName)))
}

test("SPARK-48706: Negative test case for Python UDF in higher order functions") {
assume(shouldTestPythonUDFs)
checkError(
exception = intercept[AnalysisException] {
spark.range(1).select(transform(array("id"), x => pythonTestUDF(x))).collect()
},
errorClass = "INVALID_LAMBDA_FUNCTION_CALL.UNEVALUABLE",
parameters = Map("funcName" -> "\"pyUDF(namedlambdavariable())\""),
context = ExpectedContext(
"transform", s".*${this.getClass.getSimpleName}.*"))
}
}

0 comments on commit 28f4da4

Please sign in to comment.