diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py
index 822ae46e45111..94168f041d91c 100644
--- a/python/pyspark/cloudpickle.py
+++ b/python/pyspark/cloudpickle.py
@@ -43,6 +43,7 @@
 from __future__ import print_function
 
 import operator
+import opcode
 import os
 import io
 import pickle
@@ -53,6 +54,8 @@
 import itertools
 import dis
 import traceback
+import weakref
+
 
 if sys.version < '3':
     from pickle import Pickler
@@ -68,10 +71,10 @@
     PY3 = True
 
 #relevant opcodes
-STORE_GLOBAL = dis.opname.index('STORE_GLOBAL')
-DELETE_GLOBAL = dis.opname.index('DELETE_GLOBAL')
-LOAD_GLOBAL = dis.opname.index('LOAD_GLOBAL')
-GLOBAL_OPS = [STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL]
+STORE_GLOBAL = opcode.opmap['STORE_GLOBAL']
+DELETE_GLOBAL = opcode.opmap['DELETE_GLOBAL']
+LOAD_GLOBAL = opcode.opmap['LOAD_GLOBAL']
+GLOBAL_OPS = (STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL)
 HAVE_ARGUMENT = dis.HAVE_ARGUMENT
 EXTENDED_ARG = dis.EXTENDED_ARG
 
@@ -90,6 +93,43 @@ def _builtin_type(name):
     return getattr(types, name)
 
 
+if sys.version_info < (3, 4):
+    def _walk_global_ops(code):
+        """
+        Yield (opcode, argument number) tuples for all
+        global-referencing instructions in *code*.
+        """
+        code = getattr(code, 'co_code', b'')
+        if not PY3:
+            code = map(ord, code)
+
+        n = len(code)
+        i = 0
+        extended_arg = 0
+        while i < n:
+            op = code[i]
+            i += 1
+            if op >= HAVE_ARGUMENT:
+                oparg = code[i] + code[i + 1] * 256 + extended_arg
+                extended_arg = 0
+                i += 2
+                if op == EXTENDED_ARG:
+                    extended_arg = oparg * 65536
+                if op in GLOBAL_OPS:
+                    yield op, oparg
+
+else:
+    def _walk_global_ops(code):
+        """
+        Yield (opcode, argument number) tuples for all
+        global-referencing instructions in *code*.
+        """
+        for instr in dis.get_instructions(code):
+            op = instr.opcode
+            if op in GLOBAL_OPS:
+                yield op, instr.arg
+
+
 class CloudPickler(Pickler):
 
     dispatch = Pickler.dispatch.copy()
@@ -250,38 +290,34 @@ def save_function_tuple(self, func):
         write(pickle.TUPLE)
         write(pickle.REDUCE)  # applies _fill_function on the tuple
 
-    @staticmethod
-    def extract_code_globals(co):
+    _extract_code_globals_cache = (
+        weakref.WeakKeyDictionary()
+        if sys.version_info >= (2, 7) and not hasattr(sys, "pypy_version_info")
+        else {})
+
+    @classmethod
+    def extract_code_globals(cls, co):
         """
         Find all globals names read or written to by codeblock co
         """
-        code = co.co_code
-        if not PY3:
-            code = [ord(c) for c in code]
-        names = co.co_names
-        out_names = set()
-
-        n = len(code)
-        i = 0
-        extended_arg = 0
-        while i < n:
-            op = code[i]
+        out_names = cls._extract_code_globals_cache.get(co)
+        if out_names is None:
+            try:
+                names = co.co_names
+            except AttributeError:
+                # PyPy "builtin-code" object
+                out_names = set()
+            else:
+                out_names = set(names[oparg]
+                                for op, oparg in _walk_global_ops(co))
 
-            i += 1
-            if op >= HAVE_ARGUMENT:
-                oparg = code[i] + code[i+1] * 256 + extended_arg
-                extended_arg = 0
-                i += 2
-                if op == EXTENDED_ARG:
-                    extended_arg = oparg*65536
-                if op in GLOBAL_OPS:
-                    out_names.add(names[oparg])
+            # see if nested function have any global refs
+            if co.co_consts:
+                for const in co.co_consts:
+                    if type(const) is types.CodeType:
+                        out_names |= cls.extract_code_globals(const)
 
-        # see if nested function have any global refs
-        if co.co_consts:
-            for const in co.co_consts:
-                if type(const) is types.CodeType:
-                    out_names |= CloudPickler.extract_code_globals(const)
+            cls._extract_code_globals_cache[co] = out_names
 
         return out_names
 
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 2a1326947f4f5..a9e14b8f7033b 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -370,18 +370,38 @@ def _hijack_namedtuple():
         return
 
     global _old_namedtuple  # or it will put in closure
+    global _old_namedtuple_kwdefaults  # or it will put in closure too
 
     def _copy_func(f):
         return types.FunctionType(f.__code__, f.__globals__, f.__name__,
                                   f.__defaults__, f.__closure__)
 
+    def _kwdefaults(f):
+        # __kwdefaults__ contains the default values of keyword-only arguments which are
+        # introduced from Python 3. The possible cases for __kwdefaults__ in namedtuple
+        # are as below:
+        #
+        # - Does not exist in Python 2.
+        # - Returns None in <= Python 3.5.x.
+        # - Returns a dictionary containing the default values to the keys from Python 3.6.x
+        #   (See https://bugs.python.org/issue25628).
+        kargs = getattr(f, "__kwdefaults__", None)
+        if kargs is None:
+            return {}
+        else:
+            return kargs
+
     _old_namedtuple = _copy_func(collections.namedtuple)
+    _old_namedtuple_kwdefaults = _kwdefaults(collections.namedtuple)
 
     def namedtuple(*args, **kwargs):
+        for k, v in _old_namedtuple_kwdefaults.items():
+            kwargs[k] = kwargs.get(k, v)
         cls = _old_namedtuple(*args, **kwargs)
         return _hack_namedtuple(cls)
 
     # replace namedtuple with new one
+    collections.namedtuple.__globals__["_old_namedtuple_kwdefaults"] = _old_namedtuple_kwdefaults
     collections.namedtuple.__globals__["_old_namedtuple"] = _old_namedtuple
     collections.namedtuple.__globals__["_hack_namedtuple"] = _hack_namedtuple
     collections.namedtuple.__code__ = namedtuple.__code__
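
Note (not part of the patch): a minimal standalone sketch of what the rewritten extract_code_globals() computes on Python 3.4+ via dis.get_instructions() -- the names touched by the *_GLOBAL opcodes, which cloudpickle uses to decide which module globals to pickle along with a function. The names `sample` and `SOME_GLOBAL` below are made up for the demo.

import dis

SOME_GLOBAL = [1, 2, 3]

def sample():
    return len(SOME_GLOBAL)  # both `len` and `SOME_GLOBAL` are LOAD_GLOBAL targets

global_names = {
    instr.argval
    for instr in dis.get_instructions(sample.__code__)
    if instr.opname in ("LOAD_GLOBAL", "STORE_GLOBAL", "DELETE_GLOBAL")
}
print(global_names)  # {'len', 'SOME_GLOBAL'}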
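
Note (not part of the patch): a short sketch of the Python 3.6 behaviour the new _kwdefaults() helper compensates for. types.FunctionType(...) copies the code, globals, defaults and closure but not __kwdefaults__, so a copy of collections.namedtuple made by _copy_func() loses the defaults of its keyword-only arguments (added by https://bugs.python.org/issue25628) unless the hijacked wrapper re-supplies them.

import collections
import types

def _copy_func(f):
    return types.FunctionType(f.__code__, f.__globals__, f.__name__,
                              f.__defaults__, f.__closure__)

copied = _copy_func(collections.namedtuple)
# On Python 3.6+ the original exposes a dict here (e.g. containing 'rename');
# on Python 2 the attribute does not exist, and on <= 3.5 it is None.
print(getattr(collections.namedtuple, "__kwdefaults__", None))
print(copied.__kwdefaults__)  # None -- the keyword-only defaults were dropped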