[SPARK-22939] [PySpark] Support Spark UDF in registerFunction #20137

Closed · wants to merge 13 commits
27 changes: 22 additions & 5 deletions python/pyspark/sql/catalog.py
@@ -227,15 +227,15 @@ def dropGlobalTempView(self, viewName):
@ignore_unicode_prefix
@since(2.0)
def registerFunction(self, name, f, returnType=StringType()):
"""Registers a python function (including lambda function) as a UDF
so it can be used in SQL statements.
"""Registers a Python function (including lambda function) or a :class:`UserDefinedFunction`
as a UDF. The registered UDF can be used in SQL statements.

In addition to a name and the function itself, the return type can be optionally specified.
When the return type is not given it defaults to a string and conversion will automatically
be done. For any other return type, the produced object must match the specified type.

:param name: name of the UDF
:param f: python function
:param f: a Python function, or a wrapped/native UserDefinedFunction
:param returnType: a :class:`pyspark.sql.types.DataType` object
:return: a wrapped :class:`UserDefinedFunction`

@@ -255,9 +255,26 @@ def registerFunction(self, name, f, returnType=StringType()):
>>> _ = spark.udf.register("stringLengthInt", len, IntegerType())
>>> spark.sql("SELECT stringLengthInt('test')").collect()
[Row(stringLengthInt(test)=4)]
Member:
Let's fix the doc for this too. It says ":param f: python function", but we could describe that it also takes a native Python function, a wrapped function, or a UserDefinedFunction.

Member Author:
ok


>>> import random
>>> from pyspark.sql.functions import udf
>>> from pyspark.sql.types import IntegerType, StringType
>>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
>>> newRandom_udf = spark.catalog.registerFunction("random_udf", random_udf, StringType())
>>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP
[Row(random_udf()=u'82')]
>>> spark.range(1).select(newRandom_udf()).collect() # doctest: +SKIP
[Row(random_udf()=u'62')]
"""
udf = UserDefinedFunction(f, returnType=returnType, name=name,
evalType=PythonEvalType.SQL_BATCHED_UDF)

# This is to check whether the input function is a wrapped/native UserDefinedFunction
if hasattr(f, 'asNondeterministic'):
udf = UserDefinedFunction(f.func, returnType=returnType, name=name,
evalType=PythonEvalType.SQL_BATCHED_UDF,
Contributor:
cc @ueshin @icexelloss, shall we support registering pandas UDFs here too?

Contributor:
seems we can support it by just changing evalType=PythonEvalType.SQL_BATCHED_UDF to evalType=f.evalType

Member:
+1 but I think there's no way to use a group map UDF in SQL syntax if I understood correctly. I think we can safely fail fast for now as well.

Contributor:
SGTM

Member Author:
Will support the pandas UDF as a separate PR.

Contributor:
+1 too

deterministic=f.deterministic)
else:
udf = UserDefinedFunction(f, returnType=returnType, name=name,
evalType=PythonEvalType.SQL_BATCHED_UDF)
self._jsparkSession.udf().registerPython(name, udf._judf)
return udf._wrapped()
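
To make the evalType discussion in the thread above concrete, here is a minimal sketch of the pass-through the reviewers describe, written as a variant of the if/else branch shown in this hunk. It is not what this PR does (the PR keeps SQL_BATCHED_UDF and defers pandas UDF support to a follow-up), and the grouped-map "fail fast" they mention is omitted here because SQL syntax cannot invoke such UDFs:

    # Sketch only: reuse the wrapped UDF's own eval type instead of forcing
    # SQL_BATCHED_UDF, so a registered scalar pandas UDF would keep its type.
    if hasattr(f, 'asNondeterministic'):
        udf = UserDefinedFunction(f.func, returnType=returnType, name=name,
                                  evalType=f.evalType,  # pass-through, per the review thread
                                  deterministic=f.deterministic)
    else:
        udf = UserDefinedFunction(f, returnType=returnType, name=name,
                                  evalType=PythonEvalType.SQL_BATCHED_UDF)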

16 changes: 13 additions & 3 deletions python/pyspark/sql/context.py
@@ -175,15 +175,15 @@ def range(self, start, end=None, step=1, numPartitions=None):
@ignore_unicode_prefix
@since(1.2)
def registerFunction(self, name, f, returnType=StringType()):
"""Registers a python function (including lambda function) as a UDF
so it can be used in SQL statements.
"""Registers a Python function (including lambda function) or a :class:`UserDefinedFunction`
as a UDF. The registered UDF can be used in SQL statements.

In addition to a name and the function itself, the return type can be optionally specified.
When the return type is not given it defaults to a string and conversion will automatically
be done. For any other return type, the produced object must match the specified type.

:param name: name of the UDF
:param f: python function
:param f: a Python function, or a wrapped/native UserDefinedFunction
:param returnType: a :class:`pyspark.sql.types.DataType` object
:return: a wrapped :class:`UserDefinedFunction`

@@ -203,6 +203,16 @@ def registerFunction(self, name, f, returnType=StringType()):
>>> _ = sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
[Row(stringLengthInt(test)=4)]

>>> import random
>>> from pyspark.sql.functions import udf
>>> from pyspark.sql.types import IntegerType, StringType
>>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
>>> newRandom_udf = sqlContext.registerFunction("random_udf", random_udf, StringType())
>>> sqlContext.sql("SELECT random_udf()").collect() # doctest: +SKIP
[Row(random_udf()=u'82')]
>>> sqlContext.range(1).select(newRandom_udf()).collect() # doctest: +SKIP
[Row(random_udf()=u'62')]
"""
return self.sparkSession.catalog.registerFunction(name, f, returnType)

49 changes: 35 additions & 14 deletions python/pyspark/sql/tests.py
@@ -378,6 +378,41 @@ def test_udf2(self):
[res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
self.assertEqual(4, res[0])

def test_udf3(self):
twoargs = self.spark.catalog.registerFunction(
"twoArgs", UserDefinedFunction(lambda x, y: len(x) + y), IntegerType())
self.assertEqual(twoargs.deterministic, True)
[row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
self.assertEqual(row[0], 5)

def test_nondeterministic_udf(self):
from pyspark.sql.functions import udf
import random
udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
self.assertEqual(udf_random_col.deterministic, False)
df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND'))
udf_add_ten = udf(lambda rand: rand + 10, IntegerType())
[row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect()
self.assertEqual(row[0] + 10, row[1])

def test_nondeterministic_udf2(self):
import random
from pyspark.sql.functions import udf
random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic()
self.assertEqual(random_udf.deterministic, False)
random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf, StringType())
self.assertEqual(random_udf1.deterministic, False)
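# Registered with StringType, so SQL and DataFrame results come back as the string "6".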
[row] = self.spark.sql("SELECT randInt()").collect()
self.assertEqual(row[0], "6")
[row] = self.spark.range(1).select(random_udf1()).collect()
self.assertEqual(row[0], "6")
[row] = self.spark.range(1).select(random_udf()).collect()
self.assertEqual(row[0], 6)
# render_doc() reproduces the help() exception without printing output
pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType()))
pydoc.render_doc(random_udf)
pydoc.render_doc(random_udf1)

def test_chained_udf(self):
self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType())
[row] = self.spark.sql("SELECT double(1)").collect()
@@ -435,15 +470,6 @@ def test_udf_with_array_type(self):
self.assertEqual(list(range(3)), l1)
self.assertEqual(1, l2)

def test_nondeterministic_udf(self):
from pyspark.sql.functions import udf
import random
udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND'))
udf_add_ten = udf(lambda rand: rand + 10, IntegerType())
[row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect()
self.assertEqual(row[0] + 10, row[1])

def test_broadcast_in_udf(self):
bar = {"a": "aa", "b": "bb", "c": "abc"}
foo = self.sc.broadcast(bar)
@@ -567,15 +593,13 @@ def test_read_multiple_orc_file(self):

def test_udf_with_input_file_name(self):
from pyspark.sql.functions import udf, input_file_name
from pyspark.sql.types import StringType
sourceFile = udf(lambda path: path, StringType())
filePath = "python/test_support/sql/people1.json"
row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first()
self.assertTrue(row[0].find("people1.json") != -1)

def test_udf_with_input_file_name_for_hadooprdd(self):
from pyspark.sql.functions import udf, input_file_name
from pyspark.sql.types import StringType

def filename(path):
return path
@@ -635,7 +659,6 @@ def test_udf_with_string_return_type(self):

def test_udf_shouldnt_accept_noncallable_object(self):
from pyspark.sql.functions import UserDefinedFunction
from pyspark.sql.types import StringType

non_callable = None
self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType())
@@ -1299,7 +1322,6 @@ def test_between_function(self):
df.filter(df.a.between(df.b, df.c)).collect())

def test_struct_type(self):
from pyspark.sql.types import StructType, StringType, StructField
struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
struct2 = StructType([StructField("f1", StringType(), True),
StructField("f2", StringType(), True, None)])
@@ -1368,7 +1390,6 @@ def test_parse_datatype_string(self):
_parse_datatype_string("a INT, c DOUBLE"))

def test_metadata_null(self):
from pyspark.sql.types import StructType, StringType, StructField
schema = StructType([StructField("f1", StringType(), True, None),
StructField("f2", StringType(), True, {'a': None})])
rdd = self.sc.parallelize([["a", "b"], ["c", "d"]])
21 changes: 14 additions & 7 deletions python/pyspark/sql/udf.py
@@ -56,7 +56,8 @@ def _create_udf(f, returnType, evalType):
)

# Set the name of the UserDefinedFunction object to be the name of function f
udf_obj = UserDefinedFunction(f, returnType=returnType, name=None, evalType=evalType)
udf_obj = UserDefinedFunction(
f, returnType=returnType, name=None, evalType=evalType, deterministic=True)
return udf_obj._wrapped()


@@ -67,8 +68,10 @@ class UserDefinedFunction(object):
.. versionadded:: 1.3
"""
def __init__(self, func,
returnType=StringType(), name=None,
evalType=PythonEvalType.SQL_BATCHED_UDF):
returnType=StringType(),
name=None,
evalType=PythonEvalType.SQL_BATCHED_UDF,
deterministic=True):
if not callable(func):
raise TypeError(
"Invalid function: not a function or callable (__call__ is not defined): "
@@ -92,7 +95,7 @@ def __init__(self, func,
func.__name__ if hasattr(func, '__name__')
else func.__class__.__name__)
self.evalType = evalType
self._deterministic = True
self.deterministic = deterministic

@property
def returnType(self):
@@ -130,14 +133,17 @@ def _create_judf(self):
wrapped_func = _wrap_function(sc, self.func, self.returnType)
jdt = spark._jsparkSession.parseDataType(self.returnType.json())
judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
self._name, wrapped_func, jdt, self.evalType, self._deterministic)
self._name, wrapped_func, jdt, self.evalType, self.deterministic)
return judf

def __call__(self, *cols):
judf = self._judf
sc = SparkContext._active_spark_context
return Column(judf.apply(_to_seq(sc, cols, _to_java_column)))

# This function is for improving the online help system in the interactive interpreter.
# For example, the built-in help / pydoc.help. It wraps the UDF with the docstring and
# argument annotation. (See: SPARK-19161)
Member:
I think we can put this in the docstring of _wrapped, between L148 and L150.

Member Author:
I do not want to expose these comments to the doc.

def _wrapped(self):
"""
Wrap this udf with a function and attach docstring from func
@@ -162,7 +168,8 @@ def wrapper(*args):
wrapper.func = self.func
wrapper.returnType = self.returnType
wrapper.evalType = self.evalType
wrapper.asNondeterministic = self.asNondeterministic
wrapper.deterministic = self.deterministic
wrapper.asNondeterministic = lambda: self.asNondeterministic()._wrapped()
Member:
Can we do:

       wrapper.asNondeterministic = functools.wraps(
           self.asNondeterministic)(lambda: self.asNondeterministic()._wrapped())

So that it can produce a proper pydoc when we do help(udf(lambda: 1, "integer").asNondeterministic) (not help(udf(lambda: 1, "integer").asNondeterministic())).

Member Author:
good to know the difference

Member Author:
I will leave this unchanged. Maybe you can submit a follow-up PR to address it?

Member:
Definitely. Will give a try within the following week tho ...
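
To see where the pydoc difference discussed above shows up, a small illustration (assuming a PySpark shell with this patch applied):

    import pydoc
    from pyspark.sql.functions import udf
    from pyspark.sql.types import IntegerType

    slen = udf(lambda s: len(s), IntegerType())

    # As written in this diff, the attribute is a bare lambda, so help()/pydoc
    # can only document an anonymous function with no docstring.
    print(pydoc.render_doc(slen.asNondeterministic))

    # With the functools.wraps suggestion above, the same call would surface
    # UserDefinedFunction.asNondeterministic's name and docstring instead.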


return wrapper

@@ -172,5 +179,5 @@ def asNondeterministic(self):

.. versionadded:: 2.3
"""
self._deterministic = False
self.deterministic = False
Member (@HyukjinKwon, Jan 3, 2018):
Can we call it udfDeterministic to be consistent with the Scala side?

The opposite works fine to me too if that's possible in any way.

Member Author:
deterministic is used in UserDefinedFunction.scala. Users can use it to check whether this UDF is deterministic or not.

return self
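
A short usage sketch of the attribute described above, under the changes in this diff (assuming a PySpark shell):

    import random
    from pyspark.sql.functions import udf
    from pyspark.sql.types import IntegerType

    plain_udf = udf(lambda: random.randint(0, 100), IntegerType())
    random_udf = plain_udf.asNondeterministic()

    # deterministic is now a public attribute on the wrapped UDF, mirroring
    # the deterministic flag exposed by UserDefinedFunction.scala.
    plain_udf.deterministic    # True
    random_udf.deterministic   # False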