Skip to content

Commit

Permalink
[SPARK-23334][SQL][PYTHON] Fix pandas_udf with return type StringType…
Browse files Browse the repository at this point in the history
…() to handle str type properly in Python 2.

## What changes were proposed in this pull request?

In Python 2, when `pandas_udf` tries to return string type value created in the udf with `".."`, the execution fails. E.g.,

```python
from pyspark.sql.functions import pandas_udf, col
import pandas as pd

df = spark.range(10)
str_f = pandas_udf(lambda x: pd.Series(["%s" % i for i in x]), "string")
df.select(str_f(col('id'))).show()
```

raises the following exception:

```
...

java.lang.AssertionError: assertion failed: Invalid schema from pandas_udf: expected StringType, got BinaryType
	at scala.Predef$.assert(Predef.scala:170)
	at org.apache.spark.sql.execution.python.ArrowEvalPythonExec$$anon$2.<init>(ArrowEvalPythonExec.scala:93)

...
```

Seems like pyarrow ignores `type` parameter for `pa.Array.from_pandas()` and consider it as binary type when the type is string type and the string values are `str` instead of `unicode` in Python 2.

This pr adds a workaround for the case.

## How was this patch tested?

Added a test and existing tests.

Author: Takuya UESHIN <[email protected]>

Closes apache#20507 from ueshin/issues/SPARK-23334.
  • Loading branch information
ueshin authored and HyukjinKwon committed Feb 6, 2018
1 parent 8141c3e commit 63c5bf1
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
4 changes: 4 additions & 0 deletions python/pyspark/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,10 @@ def create_array(s, t):
s = _check_series_convert_timestamps_internal(s.fillna(0), timezone)
# TODO: need cast after Arrow conversion, ns values cause error with pandas 0.19.2
return pa.Array.from_pandas(s, mask=mask).cast(t, safe=False)
elif t is not None and pa.types.is_string(t) and sys.version < '3':
# TODO: need decode before converting to Arrow in Python 2
return pa.Array.from_pandas(s.apply(
lambda v: v.decode("utf-8") if isinstance(v, str) else v), mask=mask, type=t)
return pa.Array.from_pandas(s, mask=mask, type=t)

arrs = [create_array(s, t) for s, t in series]
Expand Down
9 changes: 9 additions & 0 deletions python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3922,6 +3922,15 @@ def test_vectorized_udf_null_string(self):
res = df.select(str_f(col('str')))
self.assertEquals(df.collect(), res.collect())

def test_vectorized_udf_string_in_udf(self):
from pyspark.sql.functions import pandas_udf, col
import pandas as pd
df = self.spark.range(10)
str_f = pandas_udf(lambda x: pd.Series(map(str, x)), StringType())
actual = df.select(str_f(col('id')))
expected = df.select(col('id').cast('string'))
self.assertEquals(expected.collect(), actual.collect())

def test_vectorized_udf_datatype_string(self):
from pyspark.sql.functions import pandas_udf, col
df = self.spark.range(10).select(
Expand Down

0 comments on commit 63c5bf1

Please sign in to comment.