Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-19019][PYTHON] Fix hijacked collections.namedtuple and port cloudpickle changes for PySpark to work with Python 3.6.0 #16429

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 67 additions & 31 deletions python/pyspark/cloudpickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from __future__ import print_function

import operator
import opcode
import os
import io
import pickle
Expand All @@ -53,6 +54,8 @@
import itertools
import dis
import traceback
import weakref


if sys.version < '3':
from pickle import Pickler
Expand All @@ -68,10 +71,10 @@
PY3 = True

#relevant opcodes
STORE_GLOBAL = dis.opname.index('STORE_GLOBAL')
DELETE_GLOBAL = dis.opname.index('DELETE_GLOBAL')
LOAD_GLOBAL = dis.opname.index('LOAD_GLOBAL')
GLOBAL_OPS = [STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL]
STORE_GLOBAL = opcode.opmap['STORE_GLOBAL']
DELETE_GLOBAL = opcode.opmap['DELETE_GLOBAL']
LOAD_GLOBAL = opcode.opmap['LOAD_GLOBAL']
GLOBAL_OPS = (STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL)
HAVE_ARGUMENT = dis.HAVE_ARGUMENT
EXTENDED_ARG = dis.EXTENDED_ARG

Expand All @@ -90,6 +93,43 @@ def _builtin_type(name):
return getattr(types, name)


if sys.version_info < (3, 4):
def _walk_global_ops(code):
"""
Yield (opcode, argument number) tuples for all
global-referencing instructions in *code*.
"""
code = getattr(code, 'co_code', b'')
if not PY3:
code = map(ord, code)

n = len(code)
i = 0
extended_arg = 0
while i < n:
op = code[i]
i += 1
if op >= HAVE_ARGUMENT:
oparg = code[i] + code[i + 1] * 256 + extended_arg
extended_arg = 0
i += 2
if op == EXTENDED_ARG:
extended_arg = oparg * 65536
if op in GLOBAL_OPS:
yield op, oparg

else:
def _walk_global_ops(code):
"""
Yield (opcode, argument number) tuples for all
global-referencing instructions in *code*.
"""
for instr in dis.get_instructions(code):
op = instr.opcode
if op in GLOBAL_OPS:
yield op, instr.arg


class CloudPickler(Pickler):

dispatch = Pickler.dispatch.copy()
Expand Down Expand Up @@ -260,38 +300,34 @@ def save_function_tuple(self, func):
write(pickle.TUPLE)
write(pickle.REDUCE) # applies _fill_function on the tuple

@staticmethod
def extract_code_globals(co):
_extract_code_globals_cache = (
weakref.WeakKeyDictionary()
if sys.version_info >= (2, 7) and not hasattr(sys, "pypy_version_info")
else {})

@classmethod
def extract_code_globals(cls, co):
"""
Find all globals names read or written to by codeblock co
"""
code = co.co_code
if not PY3:
code = [ord(c) for c in code]
names = co.co_names
out_names = set()

n = len(code)
i = 0
extended_arg = 0
while i < n:
op = code[i]
out_names = cls._extract_code_globals_cache.get(co)
if out_names is None:
try:
names = co.co_names
except AttributeError:
# PyPy "builtin-code" object
out_names = set()
else:
out_names = set(names[oparg]
for op, oparg in _walk_global_ops(co))

i += 1
if op >= HAVE_ARGUMENT:
oparg = code[i] + code[i+1] * 256 + extended_arg
extended_arg = 0
i += 2
if op == EXTENDED_ARG:
extended_arg = oparg*65536
if op in GLOBAL_OPS:
out_names.add(names[oparg])
# see if nested function have any global refs
if co.co_consts:
for const in co.co_consts:
if type(const) is types.CodeType:
out_names |= cls.extract_code_globals(const)

# see if nested function have any global refs
if co.co_consts:
for const in co.co_consts:
if type(const) is types.CodeType:
out_names |= CloudPickler.extract_code_globals(const)
cls._extract_code_globals_cache[co] = out_names

return out_names

Expand Down
20 changes: 20 additions & 0 deletions python/pyspark/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,18 +382,38 @@ def _hijack_namedtuple():
return

global _old_namedtuple # or it will put in closure
global _old_namedtuple_kwdefaults # or it will put in closure too

def _copy_func(f):
return types.FunctionType(f.__code__, f.__globals__, f.__name__,
f.__defaults__, f.__closure__)

def _kwdefaults(f):
# __kwdefaults__ contains the default values of keyword-only arguments which are
# introduced from Python 3. The possible cases for __kwdefaults__ in namedtuple
# are as below:
#
# - Does not exist in Python 2.
# - Returns None in <= Python 3.5.x.
# - Returns a dictionary containing the default values to the keys from Python 3.6.x
# (See https://bugs.python.org/issue25628).
kargs = getattr(f, "__kwdefaults__", None)
if kargs is None:
return {}
else:
return kargs

_old_namedtuple = _copy_func(collections.namedtuple)
_old_namedtuple_kwdefaults = _kwdefaults(collections.namedtuple)

def namedtuple(*args, **kwargs):
for k, v in _old_namedtuple_kwdefaults.items():
kwargs[k] = kwargs.get(k, v)
cls = _old_namedtuple(*args, **kwargs)
return _hack_namedtuple(cls)

# replace namedtuple with new one
collections.namedtuple.__globals__["_old_namedtuple_kwdefaults"] = _old_namedtuple_kwdefaults
collections.namedtuple.__globals__["_old_namedtuple"] = _old_namedtuple
collections.namedtuple.__globals__["_hack_namedtuple"] = _hack_namedtuple
collections.namedtuple.__code__ = namedtuple.__code__
Expand Down