[SPARK-22901][PYTHON] Add deterministic flag to pyspark UDF
## What changes were proposed in this pull request?

In SPARK-20586 the flag `deterministic` was added to Scala UDFs, but it is not available for Python UDFs. This flag is useful when a UDF's code can return different results for the same input. Due to optimization, duplicate invocations may be eliminated or the function may even be invoked more times than it appears in the query, which can lead to unexpected behavior.

This PR adds the `deterministic` flag to Python UDFs via the `asNondeterministic` method, letting the user mark a function as non-deterministic and thereby avoid the optimizations that might lead to strange behaviors.
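As a sketch of the failure mode the flag guards against (hypothetical column names; assumes an active `SparkSession` bound to `spark`), a UDF that is *not* marked non-deterministic may be inlined and re-evaluated per use after projection collapsing:

```
>>> import random
>>> from pyspark.sql.functions import udf
>>> from pyspark.sql.types import IntegerType
>>> rand_udf = udf(lambda: int(100 * random.random()), IntegerType())  # not marked
>>> df = spark.range(1).select(rand_udf().alias('r'))
>>> # 'r' and 'r_plus_ten' may come from two separate invocations of the
>>> # lambda here, so rows with r_plus_ten != r + 10 are possible.
>>> df.withColumn('r_plus_ten', df['r'] + 10).show()
```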

## How was this patch tested?

Manual tests:
```
>>> from pyspark.sql.functions import *
>>> from pyspark.sql.types import *
>>> df_br = spark.createDataFrame([{'name': 'hello'}])
>>> import random
>>> udf_random_col = udf(lambda: int(100*random.random()), IntegerType()).asNondeterministic()
>>> df_br = df_br.withColumn('RAND', udf_random_col())
>>> random.seed(1234)
>>> udf_add_ten = udf(lambda rand: rand + 10, IntegerType())
>>> df_br.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).show()
+-----+----+-------------+
| name|RAND|RAND_PLUS_TEN|
+-----+----+-------------+
|hello|   3|           13|
+-----+----+-------------+

```
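Because the UDF is marked non-deterministic, the optimizer does not collapse `RAND` into `RAND_PLUS_TEN`: the random value is produced once and the derived column is consistently `RAND + 10`.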

Author: Marco Gaido <[email protected]>

Closes #19929 from mgaido91/SPARK-22629.
mgaido91 authored and gatorsmile committed Dec 26, 2017
1 parent eb386be commit ff48b1b
Showing 8 changed files with 48 additions and 10 deletions.
Scala `PythonEvalType` object (`org.apache.spark.api.python`):

```
@@ -39,6 +39,13 @@ private[spark] object PythonEvalType {
 
   val SQL_PANDAS_SCALAR_UDF = 200
   val SQL_PANDAS_GROUP_MAP_UDF = 201
+
+  def toString(pythonEvalType: Int): String = pythonEvalType match {
+    case NON_UDF => "NON_UDF"
+    case SQL_BATCHED_UDF => "SQL_BATCHED_UDF"
+    case SQL_PANDAS_SCALAR_UDF => "SQL_PANDAS_SCALAR_UDF"
+    case SQL_PANDAS_GROUP_MAP_UDF => "SQL_PANDAS_GROUP_MAP_UDF"
+  }
 }
 
 /**
```
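The new `toString` helper renders an eval-type constant by name rather than by its numeric value; it is used in the `UDFRegistration` debug output shown further below.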
11 changes: 8 additions & 3 deletions python/pyspark/sql/functions.py
```
@@ -2093,9 +2093,14 @@ class PandasUDFType(object):
 def udf(f=None, returnType=StringType()):
     """Creates a user defined function (UDF).
 
-    .. note:: The user-defined functions must be deterministic. Due to optimization,
-        duplicate invocations may be eliminated or the function may even be invoked more times than
-        it is present in the query.
+    .. note:: The user-defined functions are considered deterministic by default. Due to
+        optimization, duplicate invocations may be eliminated or the function may even be invoked
+        more times than it is present in the query. If your function is not deterministic, call
+        `asNondeterministic` on the user defined function. E.g.:
+
+    >>> from pyspark.sql.types import IntegerType
+    >>> import random
+    >>> random_udf = udf(lambda: int(random.random() * 100), IntegerType()).asNondeterministic()
 
     .. note:: The user-defined functions do not support conditional expressions or short circuiting
         in boolean expressions and it ends up with being executed all internally. If the functions
```
9 changes: 9 additions & 0 deletions python/pyspark/sql/tests.py
```
@@ -435,6 +435,15 @@ def test_udf_with_array_type(self):
         self.assertEqual(list(range(3)), l1)
         self.assertEqual(1, l2)
 
+    def test_nondeterministic_udf(self):
+        from pyspark.sql.functions import udf
+        import random
+        udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
+        df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND'))
+        udf_add_ten = udf(lambda rand: rand + 10, IntegerType())
+        [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect()
+        self.assertEqual(row[0] + 10, row[1])
+
     def test_broadcast_in_udf(self):
         bar = {"a": "aa", "b": "bb", "c": "abc"}
         foo = self.sc.broadcast(bar)
```
13 changes: 12 additions & 1 deletion python/pyspark/sql/udf.py
```
@@ -92,6 +92,7 @@ def __init__(self, func,
             func.__name__ if hasattr(func, '__name__')
             else func.__class__.__name__)
         self.evalType = evalType
+        self._deterministic = True
 
     @property
     def returnType(self):
@@ -129,7 +130,7 @@ def _create_judf(self):
         wrapped_func = _wrap_function(sc, self.func, self.returnType)
         jdt = spark._jsparkSession.parseDataType(self.returnType.json())
         judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
-            self._name, wrapped_func, jdt, self.evalType)
+            self._name, wrapped_func, jdt, self.evalType, self._deterministic)
         return judf
 
     def __call__(self, *cols):
@@ -161,5 +162,15 @@ def wrapper(*args):
         wrapper.func = self.func
         wrapper.returnType = self.returnType
         wrapper.evalType = self.evalType
+        wrapper.asNondeterministic = self.asNondeterministic
 
         return wrapper
+
+    def asNondeterministic(self):
+        """
+        Updates UserDefinedFunction to nondeterministic.
+
+        .. versionadded:: 2.3
+        """
+        self._deterministic = False
+        return self
```
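To tie the Python pieces together, a brief usage sketch of the intended call flow (assuming an active `SparkSession` bound to `spark`; note that in this revision `asNondeterministic` returns the underlying `UserDefinedFunction`, which is itself callable):

```
>>> import random
>>> from pyspark.sql.functions import udf
>>> from pyspark.sql.types import IntegerType
>>> # asNondeterministic() sets _deterministic = False and returns the UDF;
>>> # _create_judf later passes the flag to the JVM-side
>>> # UserDefinedPythonFunction, where it reaches the PythonUDF expression.
>>> random_udf = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
>>> df = spark.range(3).select(random_udf().alias('rand'))
```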
Scala `UDFRegistration` class (`org.apache.spark.sql`):

```
@@ -23,6 +23,7 @@ import scala.reflect.runtime.universe.TypeTag
 import scala.util.Try
 
 import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.api.python.PythonEvalType
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.api.java._
 import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
@@ -41,8 +42,6 @@ import org.apache.spark.util.Utils
  * spark.udf
  * }}}
  *
- * @note The user-defined functions must be deterministic.
- *
  * @since 1.3.0
  */
@InterfaceStability.Stable
@@ -58,6 +57,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
         | pythonIncludes: ${udf.func.pythonIncludes}
         | pythonExec: ${udf.func.pythonExec}
         | dataType: ${udf.dataType}
+        | pythonEvalType: ${PythonEvalType.toString(udf.pythonEvalType)}
+        | udfDeterministic: ${udf.udfDeterministic}
       """.stripMargin)
 
     functionRegistry.createOrReplaceTempFunction(name, udf.builder)
```
Scala `PythonUDF` expression (`org.apache.spark.sql.execution.python`):

```
@@ -29,9 +29,12 @@ case class PythonUDF(
     func: PythonFunction,
     dataType: DataType,
     children: Seq[Expression],
-    evalType: Int)
+    evalType: Int,
+    udfDeterministic: Boolean)
   extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression {
 
+  override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)
+
   override def toString: String = s"$name(${children.mkString(", ")})"
 
   override def nullable: Boolean = true
```
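The `deterministic` override above carries the semantics: a `PythonUDF` expression is deterministic only if the user has not marked it non-deterministic and all of its child expressions are deterministic, which is the property the optimizer checks before eliminating or duplicating invocations.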
Scala `UserDefinedPythonFunction` wrapper (`org.apache.spark.sql.execution.python`):

```
@@ -29,10 +29,11 @@ case class UserDefinedPythonFunction(
     name: String,
     func: PythonFunction,
     dataType: DataType,
-    pythonEvalType: Int) {
+    pythonEvalType: Int,
+    udfDeterministic: Boolean) {
 
   def builder(e: Seq[Expression]): PythonUDF = {
-    PythonUDF(name, func, dataType, e, pythonEvalType)
+    PythonUDF(name, func, dataType, e, pythonEvalType, udfDeterministic)
   }
 
   /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */
```
Scala test helper `MyDummyPythonUDF`:

```
@@ -109,4 +109,5 @@ class MyDummyPythonUDF extends UserDefinedPythonFunction(
   name = "dummyUDF",
   func = new DummyUDF,
   dataType = BooleanType,
-  pythonEvalType = PythonEvalType.SQL_BATCHED_UDF)
+  pythonEvalType = PythonEvalType.SQL_BATCHED_UDF,
+  udfDeterministic = true)
```
