[SPARK-23122][PYTHON][SQL] Deprecate register* for UDFs in SQLContext and Catalog in PySpark

## What changes were proposed in this pull request?

This PR proposes to deprecate `register*` for UDFs in `SQLContext` and `Catalog` in Spark 2.3.0.

These APIs are inconsistent with the Scala / Java APIs, and they essentially duplicate `spark.udf.register*`.

This PR also moves the logic from `[sqlContext|spark.catalog].register*` to `spark.udf.register*` and reuses the docstrings.

This PR also makes minor documentation corrections and includes #20158.
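For illustration (not part of the patch), here is a minimal sketch of the migration this deprecation implies; the session setup, UDF name, and lambda are placeholders, assuming Spark 2.3.0 or later:

```python
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.master("local[1]").appName("udf-register-migration").getOrCreate()

# Before (deprecated in 2.3.0): register through SQLContext or Catalog.
#   sqlContext.registerFunction("strlen", lambda s: len(s), IntegerType())
#   spark.catalog.registerFunction("strlen", lambda s: len(s), IntegerType())

# After: register through spark.udf, consistent with the Scala / Java APIs.
strlen = spark.udf.register("strlen", lambda s: len(s), IntegerType())

# The registered name is usable in SQL, and the returned UDF in the DataFrame API.
spark.sql("SELECT strlen('test')").show()
spark.sql("SELECT 'foo' AS text").select(strlen("text")).show()

spark.stop()
```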

## How was this patch tested?

Manually tested, manually checked the API documentation, and added tests to check that the deprecated APIs call their aliases correctly.

Author: hyukjinkwon <[email protected]>

Closes #20288 from HyukjinKwon/deprecate-udf.

(cherry picked from commit 39d244d)
Signed-off-by: Takuya UESHIN <[email protected]>
HyukjinKwon authored and ueshin committed Jan 18, 2018
1 parent f2688ef commit 3a80cc5
Showing 8 changed files with 234 additions and 210 deletions.
1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
@@ -400,6 +400,7 @@ def __hash__(self):
"pyspark.sql.functions",
"pyspark.sql.readwriter",
"pyspark.sql.streaming",
"pyspark.sql.udf",
"pyspark.sql.window",
"pyspark.sql.tests",
]
91 changes: 8 additions & 83 deletions python/pyspark/sql/catalog.py
@@ -224,92 +224,17 @@ def dropGlobalTempView(self, viewName):
"""
self._jcatalog.dropGlobalTempView(viewName)

@ignore_unicode_prefix
@since(2.0)
def registerFunction(self, name, f, returnType=None):
"""Registers a Python function (including lambda function) or a :class:`UserDefinedFunction`
as a UDF. The registered UDF can be used in SQL statements.
:func:`spark.udf.register` is an alias for :func:`spark.catalog.registerFunction`.
In addition to a name and the function itself, `returnType` can be optionally specified.
1) When f is a Python function, `returnType` defaults to a string. The produced object must
match the specified type. 2) When f is a :class:`UserDefinedFunction`, Spark uses the return
type of the given UDF as the return type of the registered UDF. The input parameter
`returnType` is None by default. If given by users, the value must be None.
:param name: name of the UDF in SQL statements.
:param f: a Python function, or a wrapped/native UserDefinedFunction. The UDF can be either
row-at-a-time or vectorized.
:param returnType: the return type of the registered UDF.
:return: a wrapped/native :class:`UserDefinedFunction`
>>> strlen = spark.catalog.registerFunction("stringLengthString", len)
>>> spark.sql("SELECT stringLengthString('test')").collect()
[Row(stringLengthString(test)=u'4')]
>>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect()
[Row(stringLengthString(text)=u'3')]
>>> from pyspark.sql.types import IntegerType
>>> _ = spark.catalog.registerFunction("stringLengthInt", len, IntegerType())
>>> spark.sql("SELECT stringLengthInt('test')").collect()
[Row(stringLengthInt(test)=4)]
>>> from pyspark.sql.types import IntegerType
>>> _ = spark.udf.register("stringLengthInt", len, IntegerType())
>>> spark.sql("SELECT stringLengthInt('test')").collect()
[Row(stringLengthInt(test)=4)]
>>> from pyspark.sql.types import IntegerType
>>> from pyspark.sql.functions import udf
>>> slen = udf(lambda s: len(s), IntegerType())
>>> _ = spark.udf.register("slen", slen)
>>> spark.sql("SELECT slen('test')").collect()
[Row(slen(test)=4)]
>>> import random
>>> from pyspark.sql.functions import udf
>>> from pyspark.sql.types import IntegerType
>>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
>>> new_random_udf = spark.catalog.registerFunction("random_udf", random_udf)
>>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP
[Row(random_udf()=82)]
>>> spark.range(1).select(new_random_udf()).collect() # doctest: +SKIP
[Row(<lambda>()=26)]
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP
... def add_one(x):
... return x + 1
...
>>> _ = spark.udf.register("add_one", add_one) # doctest: +SKIP
>>> spark.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP
[Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]
"""
"""An alias for :func:`spark.udf.register`.
See :meth:`pyspark.sql.UDFRegistration.register`.
# This is to check whether the input function is a wrapped/native UserDefinedFunction
if hasattr(f, 'asNondeterministic'):
if returnType is not None:
raise TypeError(
"Invalid returnType: None is expected when f is a UserDefinedFunction, "
"but got %s." % returnType)
if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF,
PythonEvalType.SQL_PANDAS_SCALAR_UDF]:
raise ValueError(
"Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF")
register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name,
evalType=f.evalType,
deterministic=f.deterministic)
return_udf = f
else:
if returnType is None:
returnType = StringType()
register_udf = UserDefinedFunction(f, returnType=returnType, name=name,
evalType=PythonEvalType.SQL_BATCHED_UDF)
return_udf = register_udf._wrapped()
self._jsparkSession.udf().registerPython(name, register_udf._judf)
return return_udf
.. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead.
"""
warnings.warn(
"Deprecated in 2.3.0. Use spark.udf.register instead.",
DeprecationWarning)
return self._sparkSession.udf.register(name, f, returnType)

@since(2.0)
def isCached(self, tableName):
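As an illustrative aside (not part of the diff), a minimal sketch of what the deprecated `Catalog.registerFunction` alias above now does for callers, assuming Spark 2.3.0 with this patch; the UDF name is a placeholder:

```python
import warnings

from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.master("local[1]").appName("deprecated-catalog-register").getOrCreate()

# DeprecationWarning is silenced by default; surface it to see the message added above.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    spark.catalog.registerFunction("one_arg", lambda x: len(x), IntegerType())

assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# The alias still registers the UDF, because it now delegates to spark.udf.register.
assert spark.sql("SELECT one_arg('test')").collect()[0][0] == 4

spark.stop()
```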
137 changes: 18 additions & 119 deletions python/pyspark/sql/context.py
@@ -29,9 +29,10 @@
from pyspark.sql.readwriter import DataFrameReader
from pyspark.sql.streaming import DataStreamReader
from pyspark.sql.types import IntegerType, Row, StringType
from pyspark.sql.udf import UDFRegistration
from pyspark.sql.utils import install_exception_handler

__all__ = ["SQLContext", "HiveContext", "UDFRegistration"]
__all__ = ["SQLContext", "HiveContext"]


class SQLContext(object):
@@ -147,7 +148,7 @@ def udf(self):
:return: :class:`UDFRegistration`
"""
return UDFRegistration(self)
return self.sparkSession.udf

@since(1.4)
def range(self, start, end=None, step=1, numPartitions=None):
@@ -172,113 +173,29 @@ def range(self, start, end=None, step=1, numPartitions=None):
"""
return self.sparkSession.range(start, end, step, numPartitions)

@ignore_unicode_prefix
@since(1.2)
def registerFunction(self, name, f, returnType=None):
"""Registers a Python function (including lambda function) or a :class:`UserDefinedFunction`
as a UDF. The registered UDF can be used in SQL statements.
:func:`spark.udf.register` is an alias for :func:`sqlContext.registerFunction`.
In addition to a name and the function itself, `returnType` can be optionally specified.
1) When f is a Python function, `returnType` defaults to a string. The produced object must
match the specified type. 2) When f is a :class:`UserDefinedFunction`, Spark uses the return
type of the given UDF as the return type of the registered UDF. The input parameter
`returnType` is None by default. If given by users, the value must be None.
:param name: name of the UDF in SQL statements.
:param f: a Python function, or a wrapped/native UserDefinedFunction. The UDF can be either
row-at-a-time or vectorized.
:param returnType: the return type of the registered UDF.
:return: a wrapped/native :class:`UserDefinedFunction`
>>> strlen = sqlContext.registerFunction("stringLengthString", lambda x: len(x))
>>> sqlContext.sql("SELECT stringLengthString('test')").collect()
[Row(stringLengthString(test)=u'4')]
>>> sqlContext.sql("SELECT 'foo' AS text").select(strlen("text")).collect()
[Row(stringLengthString(text)=u'3')]
>>> from pyspark.sql.types import IntegerType
>>> _ = sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
[Row(stringLengthInt(test)=4)]
>>> from pyspark.sql.types import IntegerType
>>> _ = sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
[Row(stringLengthInt(test)=4)]
>>> from pyspark.sql.types import IntegerType
>>> from pyspark.sql.functions import udf
>>> slen = udf(lambda s: len(s), IntegerType())
>>> _ = sqlContext.udf.register("slen", slen)
>>> sqlContext.sql("SELECT slen('test')").collect()
[Row(slen(test)=4)]
>>> import random
>>> from pyspark.sql.functions import udf
>>> from pyspark.sql.types import IntegerType
>>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
>>> new_random_udf = sqlContext.registerFunction("random_udf", random_udf)
>>> sqlContext.sql("SELECT random_udf()").collect() # doctest: +SKIP
[Row(random_udf()=82)]
>>> sqlContext.range(1).select(new_random_udf()).collect() # doctest: +SKIP
[Row(<lambda>()=26)]
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP
... def add_one(x):
... return x + 1
...
>>> _ = sqlContext.udf.register("add_one", add_one) # doctest: +SKIP
>>> sqlContext.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP
[Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]
"""An alias for :func:`spark.udf.register`.
See :meth:`pyspark.sql.UDFRegistration.register`.
.. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead.
"""
return self.sparkSession.catalog.registerFunction(name, f, returnType)
warnings.warn(
"Deprecated in 2.3.0. Use spark.udf.register instead.",
DeprecationWarning)
return self.sparkSession.udf.register(name, f, returnType)

@ignore_unicode_prefix
@since(2.1)
def registerJavaFunction(self, name, javaClassName, returnType=None):
"""Register a java UDF so it can be used in SQL statements.
In addition to a name and the function itself, the return type can be optionally specified.
When the return type is not specified we would infer it via reflection.
:param name: name of the UDF
:param javaClassName: fully qualified name of java class
:param returnType: a :class:`pyspark.sql.types.DataType` object
>>> sqlContext.registerJavaFunction("javaStringLength",
... "test.org.apache.spark.sql.JavaStringLength", IntegerType())
>>> sqlContext.sql("SELECT javaStringLength('test')").collect()
[Row(UDF:javaStringLength(test)=4)]
>>> sqlContext.registerJavaFunction("javaStringLength2",
... "test.org.apache.spark.sql.JavaStringLength")
>>> sqlContext.sql("SELECT javaStringLength2('test')").collect()
[Row(UDF:javaStringLength2(test)=4)]
"""An alias for :func:`spark.udf.registerJavaFunction`.
See :meth:`pyspark.sql.UDFRegistration.registerJavaFunction`.
.. note:: Deprecated in 2.3.0. Use :func:`spark.udf.registerJavaFunction` instead.
"""
jdt = None
if returnType is not None:
jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json())
self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt)

@ignore_unicode_prefix
@since(2.3)
def registerJavaUDAF(self, name, javaClassName):
"""Register a java UDAF so it can be used in SQL statements.
:param name: name of the UDAF
:param javaClassName: fully qualified name of java class
>>> sqlContext.registerJavaUDAF("javaUDAF",
... "test.org.apache.spark.sql.MyDoubleAvg")
>>> df = sqlContext.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"])
>>> df.registerTempTable("df")
>>> sqlContext.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect()
[Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)]
"""
self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName)
warnings.warn(
"Deprecated in 2.3.0. Use spark.udf.registerJavaFunction instead.",
DeprecationWarning)
return self.sparkSession.udf.registerJavaFunction(name, javaClassName, returnType)

# TODO(andrew): delete this once we refactor things to take in SparkSession
def _inferSchema(self, rdd, samplingRatio=None):
@@ -590,24 +507,6 @@ def refreshTable(self, tableName):
self._ssql_ctx.refreshTable(tableName)


class UDFRegistration(object):
"""Wrapper for user-defined function registration."""

def __init__(self, sqlContext):
self.sqlContext = sqlContext

def register(self, name, f, returnType=None):
return self.sqlContext.registerFunction(name, f, returnType)

def registerJavaFunction(self, name, javaClassName, returnType=None):
self.sqlContext.registerJavaFunction(name, javaClassName, returnType)

def registerJavaUDAF(self, name, javaClassName):
self.sqlContext.registerJavaUDAF(name, javaClassName)

register.__doc__ = SQLContext.registerFunction.__doc__


def _test():
import os
import doctest
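A similar sketch (again not part of the diff) for the SQLContext side, where `sqlContext.udf` now returns the session-level UDFRegistration; `spark._wrapped` mirrors how the added tests reach the SQLContext, and the UDF names are placeholders:

```python
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.master("local[1]").appName("sqlcontext-udf-alias").getOrCreate()
sqlContext = spark._wrapped  # same shortcut the added tests use

# sqlContext.udf is now the session's UDFRegistration, so both entry points
# register into the same place.
sqlContext.udf.register("plus_one", lambda x: x + 1, IntegerType())
assert spark.sql("SELECT plus_one(41)").collect()[0][0] == 42

# The deprecated alias still works; it warns and delegates to spark.udf.register.
sqlContext.registerFunction("times_two", lambda x: x * 2, IntegerType())
assert sqlContext.sql("SELECT times_two(21)").collect()[0][0] == 42

spark.stop()
```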
4 changes: 2 additions & 2 deletions python/pyspark/sql/functions.py
@@ -2103,7 +2103,7 @@ def udf(f=None, returnType=StringType()):
>>> import random
>>> random_udf = udf(lambda: int(random.random() * 100), IntegerType()).asNondeterministic()
.. note:: The user-defined functions do not support conditional expressions or short curcuiting
.. note:: The user-defined functions do not support conditional expressions or short circuiting
in boolean expressions and it ends up with being executed all internally. If the functions
can fail on special rows, the workaround is to incorporate the condition into the functions.
@@ -2231,7 +2231,7 @@ def pandas_udf(f=None, returnType=None, functionType=None):
... return pd.Series(np.random.randn(len(v))
>>> random = random.asNondeterministic() # doctest: +SKIP
.. note:: The user-defined functions do not support conditional expressions or short curcuiting
.. note:: The user-defined functions do not support conditional expressions or short circuiting
in boolean expressions and it ends up with being executed all internally. If the functions
can fail on special rows, the workaround is to incorporate the condition into the functions.
"""
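The typo fix above touches a note about short circuiting; for context, a small sketch (not from the patch) of the workaround that note describes, with illustrative data and names:

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.master("local[1]").appName("udf-short-circuit").getOrCreate()
df = spark.createDataFrame([("test",), (None,)], ["s"])

# Wrapping the UDF in a conditional does not guarantee it is skipped for the
# null row, because UDFs may be evaluated regardless of the surrounding
# expression, so this can still fail:
#     unsafe_len = F.udf(lambda s: len(s), IntegerType())
#     df.select(F.when(df.s.isNotNull(), unsafe_len(df.s))).collect()

# Workaround described in the note: move the guard into the function itself.
safe_len = F.udf(lambda s: len(s) if s is not None else None, IntegerType())
df.select(safe_len(df.s).alias("len")).show()

spark.stop()
```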
3 changes: 2 additions & 1 deletion python/pyspark/sql/group.py
@@ -212,7 +212,8 @@ def apply(self, udf):
This function does not support partial aggregation, and requires shuffling all the data in
the :class:`DataFrame`.
:param udf: A function object returned by :meth:`pyspark.sql.functions.pandas_udf`
:param udf: a group map user-defined function returned by
:meth:`pyspark.sql.functions.pandas_udf`.
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> df = spark.createDataFrame(
6 changes: 3 additions & 3 deletions python/pyspark/sql/session.py
@@ -29,7 +29,6 @@

from pyspark import since
from pyspark.rdd import RDD, ignore_unicode_prefix
from pyspark.sql.catalog import Catalog
from pyspark.sql.conf import RuntimeConfig
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.readwriter import DataFrameReader
@@ -280,6 +279,7 @@ def catalog(self):
:return: :class:`Catalog`
"""
from pyspark.sql.catalog import Catalog
if not hasattr(self, "_catalog"):
self._catalog = Catalog(self)
return self._catalog
@@ -291,8 +291,8 @@ def udf(self):
:return: :class:`UDFRegistration`
"""
from pyspark.sql.context import UDFRegistration
return UDFRegistration(self._wrapped)
from pyspark.sql.udf import UDFRegistration
return UDFRegistration(self)

@since(2.0)
def range(self, start, end=None, step=1, numPartitions=None):
20 changes: 20 additions & 0 deletions python/pyspark/sql/tests.py
@@ -372,6 +372,12 @@ def test_udf(self):
[row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
self.assertEqual(row[0], 5)

# This is to check if a deprecated 'SQLContext.registerFunction' can call its alias.
sqlContext = self.spark._wrapped
sqlContext.registerFunction("oneArg", lambda x: len(x), IntegerType())
[row] = sqlContext.sql("SELECT oneArg('test')").collect()
self.assertEqual(row[0], 4)

def test_udf2(self):
self.spark.catalog.registerFunction("strlen", lambda string: len(string), IntegerType())
self.spark.createDataFrame(self.sc.parallelize([Row(a="test")]))\
@@ -577,11 +583,25 @@ def test_udf_registration_returns_udf(self):
df.select(add_three("id").alias("plus_three")).collect()
)

# This is to check if a 'SQLContext.udf' can call its alias.
sqlContext = self.spark._wrapped
add_four = sqlContext.udf.register("add_four", lambda x: x + 4, IntegerType())

self.assertListEqual(
df.selectExpr("add_four(id) AS plus_four").collect(),
df.select(add_four("id").alias("plus_four")).collect()
)

def test_non_existed_udf(self):
spark = self.spark
self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf",
lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf"))

# This is to check if a deprecated 'SQLContext.registerJavaFunction' can call its alias.
sqlContext = spark._wrapped
self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf",
lambda: sqlContext.registerJavaFunction("udf1", "non_existed_udf"))

def test_non_existed_udaf(self):
spark = self.spark
self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf",
(Diff for the remaining changed file, likely python/pyspark/sql/udf.py, did not render on this page.)
