diff --git a/cloudpickle/cloudpickle.py b/cloudpickle/cloudpickle.py index 287691f55..c89ad2aef 100644 --- a/cloudpickle/cloudpickle.py +++ b/cloudpickle/cloudpickle.py @@ -359,10 +359,13 @@ def save_global(self, obj, name=None, pack=struct.pack): def save_instancemethod(self, obj): # Memoization rarely is ever useful due to python bounding - if PY3: - self.save_reduce(types.MethodType, (obj.__func__, obj.__self__), obj=obj) + if obj.__self__ is None: + self.save_reduce(getattr, (obj.im_class, obj.__name__)) else: - self.save_reduce(types.MethodType, (obj.__func__, obj.__self__, obj.__self__.__class__), + if PY3: + self.save_reduce(types.MethodType, (obj.__func__, obj.__self__), obj=obj) + else: + self.save_reduce(types.MethodType, (obj.__func__, obj.__self__, obj.__self__.__class__), obj=obj) dispatch[types.MethodType] = save_instancemethod @@ -698,3 +701,18 @@ def _make_skel_func(code, closures, base_globals = None): def _getobject(modname, attribute): mod = __import__(modname, fromlist=[attribute]) return mod.__dict__[attribute] + + +""" Use copy_reg to extend global pickle definitions """ + +if sys.version_info < (3, 4): + method_descriptor = type(str.upper) + + def _reduce_method_descriptor(obj): + return (getattr, (obj.__objclass__, obj.__name__)) + + try: + import copy_reg as copyreg + except ImportError: + import copyreg + copyreg.pickle(method_descriptor, _reduce_method_descriptor) diff --git a/tests/cloudpickle_test.py b/tests/cloudpickle_test.py index 8ecf1237a..fd8f399cc 100644 --- a/tests/cloudpickle_test.py +++ b/tests/cloudpickle_test.py @@ -237,5 +237,21 @@ def test_cm(cls): self.assertEqual(A.test_sm(), "sm") self.assertEqual(A.test_cm(), "cm") + def test_method_descriptors(self): + f = pickle_depickle(str.upper) + self.assertEqual(f('abc'), 'ABC') + + def test_instancemethods_without_self(self): + class F(object): + def f(self, x): + return x + 1 + + g = pickle_depickle(F.f) + self.assertEqual(g.__name__, F.f.__name__) + if sys.version_info[0] < 3: + self.assertEqual(g.im_class.__name__, F.f.im_class.__name__) + # self.assertEqual(g(F(), 1), 2) # still fails + + if __name__ == '__main__': unittest.main()