Skip to content

Commit

Permalink
[SPARK-13697] [PYSPARK] Fix the missing module name of TransformFunct…
Browse files Browse the repository at this point in the history
…ionSerializer.loads

## What changes were proposed in this pull request?

Set the function's module name to `__main__` if it's missing in `TransformFunctionSerializer.loads`.

## How was this patch tested?

Manually test in the shell.

Before this patch:
```
>>> from pyspark.streaming import StreamingContext
>>> from pyspark.streaming.util import TransformFunction
>>> ssc = StreamingContext(sc, 1)
>>> func = TransformFunction(sc, lambda x: x, sc.serializer)
>>> func.rdd_wrapper(lambda x: x)
TransformFunction(<function <lambda> at 0x106ac8b18>)
>>> bytes = bytearray(ssc._transformerSerializer.serializer.dumps((func.func, func.rdd_wrap_func, func.deserializers)))
>>> func2 = ssc._transformerSerializer.loads(bytes)
>>> print(func2.func.__module__)
None
>>> print(func2.rdd_wrap_func.__module__)
None
>>>
```
After this patch:
```
>>> from pyspark.streaming import StreamingContext
>>> from pyspark.streaming.util import TransformFunction
>>> ssc = StreamingContext(sc, 1)
>>> func = TransformFunction(sc, lambda x: x, sc.serializer)
>>> func.rdd_wrapper(lambda x: x)
TransformFunction(<function <lambda> at 0x108bf1b90>)
>>> bytes = bytearray(ssc._transformerSerializer.serializer.dumps((func.func, func.rdd_wrap_func, func.deserializers)))
>>> func2 = ssc._transformerSerializer.loads(bytes)
>>> print(func2.func.__module__)
__main__
>>> print(func2.rdd_wrap_func.__module__)
__main__
>>>
```

Author: Shixiong Zhu <[email protected]>

Closes #11535 from zsxwing/loads-module.

(cherry picked from commit ee913e6)
Signed-off-by: Davies Liu <[email protected]>
  • Loading branch information
zsxwing authored and davies committed Mar 6, 2016
1 parent ffaf7c0 commit 704a54c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
4 changes: 3 additions & 1 deletion python/pyspark/cloudpickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ def save_function_tuple(self, func):
save(f_globals)
save(defaults)
save(dct)
save(func.__module__)
write(pickle.TUPLE)
write(pickle.REDUCE) # applies _fill_function on the tuple

Expand Down Expand Up @@ -698,13 +699,14 @@ def _genpartial(func, args, kwds):
return partial(func, *args, **kwds)


def _fill_function(func, globals, defaults, dict):
def _fill_function(func, globals, defaults, dict, module):
""" Fills in the rest of function data into the skeleton function object
that were created via _make_skel_func().
"""
func.__globals__.update(globals)
func.__defaults__ = defaults
func.__dict__ = dict
func.__module__ = module

return func

Expand Down
6 changes: 6 additions & 0 deletions python/pyspark/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,12 @@ def test_itemgetter(self):
getter2 = ser.loads(ser.dumps(getter))
self.assertEqual(getter(d), getter2(d))

def test_function_module_name(self):
ser = CloudPickleSerializer()
func = lambda x: x
func2 = ser.loads(ser.dumps(func))
self.assertEqual(func.__module__, func2.__module__)

def test_attrgetter(self):
from operator import attrgetter
ser = CloudPickleSerializer()
Expand Down

0 comments on commit 704a54c

Please sign in to comment.