Skip to content

Commit

Permalink
[SPARK-19019][PYTHON][BRANCH-1.6] Fix hijacked `collections.namedtupl…
Browse files Browse the repository at this point in the history
…e` and port cloudpickle changes for PySpark to work with Python 3.6.0

## What changes were proposed in this pull request?

This PR proposes to backports #16429 to branch-1.6 so that Python 3.6.0 works with Spark 1.6.x.

## How was this patch tested?

Manually, via

```
./run-tests --python-executables=python3.6
```

```
Finished test(python3.6): pyspark.conf (5s)
Finished test(python3.6): pyspark.broadcast (7s)
Finished test(python3.6): pyspark.accumulators (9s)
Finished test(python3.6): pyspark.rdd (16s)
Finished test(python3.6): pyspark.shuffle (0s)
Finished test(python3.6): pyspark.serializers (11s)
Finished test(python3.6): pyspark.profiler (5s)
Finished test(python3.6): pyspark.context (21s)
Finished test(python3.6): pyspark.ml.clustering (12s)
Finished test(python3.6): pyspark.ml.feature (16s)
Finished test(python3.6): pyspark.ml.classification (16s)
Finished test(python3.6): pyspark.ml.recommendation (16s)
Finished test(python3.6): pyspark.ml.tuning (14s)
Finished test(python3.6): pyspark.ml.regression (16s)
Finished test(python3.6): pyspark.ml.evaluation (12s)
Finished test(python3.6): pyspark.ml.tests (17s)
Finished test(python3.6): pyspark.mllib.classification (18s)
Finished test(python3.6): pyspark.mllib.evaluation (12s)
Finished test(python3.6): pyspark.mllib.feature (19s)
Finished test(python3.6): pyspark.mllib.linalg.__init__ (0s)
Finished test(python3.6): pyspark.mllib.fpm (12s)
Finished test(python3.6): pyspark.mllib.clustering (31s)
Finished test(python3.6): pyspark.mllib.random (8s)
Finished test(python3.6): pyspark.mllib.linalg.distributed (17s)
Finished test(python3.6): pyspark.mllib.recommendation (23s)
Finished test(python3.6): pyspark.mllib.stat.KernelDensity (0s)
Finished test(python3.6): pyspark.mllib.stat._statistics (13s)
Finished test(python3.6): pyspark.mllib.regression (22s)
Finished test(python3.6): pyspark.mllib.util (9s)
Finished test(python3.6): pyspark.mllib.tree (14s)
Finished test(python3.6): pyspark.sql.types (9s)
Finished test(python3.6): pyspark.sql.context (16s)
Finished test(python3.6): pyspark.sql.column (14s)
Finished test(python3.6): pyspark.sql.group (16s)
Finished test(python3.6): pyspark.sql.dataframe (25s)
Finished test(python3.6): pyspark.tests (164s)
Finished test(python3.6): pyspark.sql.window (6s)
Finished test(python3.6): pyspark.sql.functions (19s)
Finished test(python3.6): pyspark.streaming.util (0s)
Finished test(python3.6): pyspark.sql.readwriter (24s)
Finished test(python3.6): pyspark.sql.tests (38s)
Finished test(python3.6): pyspark.mllib.tests (133s)
Finished test(python3.6): pyspark.streaming.tests (189s)
Tests passed in 380 seconds
```

Author: hyukjinkwon <[email protected]>

Closes #17375 from HyukjinKwon/SPARK-19019-backport-1.6.
  • Loading branch information
HyukjinKwon authored and holdenk committed Apr 17, 2017
1 parent 23f9faa commit 6b315f3
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 31 deletions.
98 changes: 67 additions & 31 deletions python/pyspark/cloudpickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from __future__ import print_function

import operator
import opcode
import os
import io
import pickle
Expand All @@ -53,6 +54,8 @@
import itertools
import dis
import traceback
import weakref


if sys.version < '3':
from pickle import Pickler
Expand All @@ -68,10 +71,10 @@
PY3 = True

#relevant opcodes
STORE_GLOBAL = dis.opname.index('STORE_GLOBAL')
DELETE_GLOBAL = dis.opname.index('DELETE_GLOBAL')
LOAD_GLOBAL = dis.opname.index('LOAD_GLOBAL')
GLOBAL_OPS = [STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL]
STORE_GLOBAL = opcode.opmap['STORE_GLOBAL']
DELETE_GLOBAL = opcode.opmap['DELETE_GLOBAL']
LOAD_GLOBAL = opcode.opmap['LOAD_GLOBAL']
GLOBAL_OPS = (STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL)
HAVE_ARGUMENT = dis.HAVE_ARGUMENT
EXTENDED_ARG = dis.EXTENDED_ARG

Expand All @@ -90,6 +93,43 @@ def _builtin_type(name):
return getattr(types, name)


if sys.version_info < (3, 4):
def _walk_global_ops(code):
"""
Yield (opcode, argument number) tuples for all
global-referencing instructions in *code*.
"""
code = getattr(code, 'co_code', b'')
if not PY3:
code = map(ord, code)

n = len(code)
i = 0
extended_arg = 0
while i < n:
op = code[i]
i += 1
if op >= HAVE_ARGUMENT:
oparg = code[i] + code[i + 1] * 256 + extended_arg
extended_arg = 0
i += 2
if op == EXTENDED_ARG:
extended_arg = oparg * 65536
if op in GLOBAL_OPS:
yield op, oparg

else:
def _walk_global_ops(code):
"""
Yield (opcode, argument number) tuples for all
global-referencing instructions in *code*.
"""
for instr in dis.get_instructions(code):
op = instr.opcode
if op in GLOBAL_OPS:
yield op, instr.arg


class CloudPickler(Pickler):

dispatch = Pickler.dispatch.copy()
Expand Down Expand Up @@ -250,38 +290,34 @@ def save_function_tuple(self, func):
write(pickle.TUPLE)
write(pickle.REDUCE) # applies _fill_function on the tuple

@staticmethod
def extract_code_globals(co):
_extract_code_globals_cache = (
weakref.WeakKeyDictionary()
if sys.version_info >= (2, 7) and not hasattr(sys, "pypy_version_info")
else {})

@classmethod
def extract_code_globals(cls, co):
"""
Find all globals names read or written to by codeblock co
"""
code = co.co_code
if not PY3:
code = [ord(c) for c in code]
names = co.co_names
out_names = set()

n = len(code)
i = 0
extended_arg = 0
while i < n:
op = code[i]
out_names = cls._extract_code_globals_cache.get(co)
if out_names is None:
try:
names = co.co_names
except AttributeError:
# PyPy "builtin-code" object
out_names = set()
else:
out_names = set(names[oparg]
for op, oparg in _walk_global_ops(co))

i += 1
if op >= HAVE_ARGUMENT:
oparg = code[i] + code[i+1] * 256 + extended_arg
extended_arg = 0
i += 2
if op == EXTENDED_ARG:
extended_arg = oparg*65536
if op in GLOBAL_OPS:
out_names.add(names[oparg])
# see if nested function have any global refs
if co.co_consts:
for const in co.co_consts:
if type(const) is types.CodeType:
out_names |= cls.extract_code_globals(const)

# see if nested function have any global refs
if co.co_consts:
for const in co.co_consts:
if type(const) is types.CodeType:
out_names |= CloudPickler.extract_code_globals(const)
cls._extract_code_globals_cache[co] = out_names

return out_names

Expand Down
20 changes: 20 additions & 0 deletions python/pyspark/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,18 +370,38 @@ def _hijack_namedtuple():
return

global _old_namedtuple # or it will put in closure
global _old_namedtuple_kwdefaults # or it will put in closure too

def _copy_func(f):
return types.FunctionType(f.__code__, f.__globals__, f.__name__,
f.__defaults__, f.__closure__)

def _kwdefaults(f):
# __kwdefaults__ contains the default values of keyword-only arguments which are
# introduced from Python 3. The possible cases for __kwdefaults__ in namedtuple
# are as below:
#
# - Does not exist in Python 2.
# - Returns None in <= Python 3.5.x.
# - Returns a dictionary containing the default values to the keys from Python 3.6.x
# (See https://bugs.python.org/issue25628).
kargs = getattr(f, "__kwdefaults__", None)
if kargs is None:
return {}
else:
return kargs

_old_namedtuple = _copy_func(collections.namedtuple)
_old_namedtuple_kwdefaults = _kwdefaults(collections.namedtuple)

def namedtuple(*args, **kwargs):
for k, v in _old_namedtuple_kwdefaults.items():
kwargs[k] = kwargs.get(k, v)
cls = _old_namedtuple(*args, **kwargs)
return _hack_namedtuple(cls)

# replace namedtuple with new one
collections.namedtuple.__globals__["_old_namedtuple_kwdefaults"] = _old_namedtuple_kwdefaults
collections.namedtuple.__globals__["_old_namedtuple"] = _old_namedtuple
collections.namedtuple.__globals__["_hack_namedtuple"] = _hack_namedtuple
collections.namedtuple.__code__ = namedtuple.__code__
Expand Down

0 comments on commit 6b315f3

Please sign in to comment.