Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pickle setstate: setattr __dict__ only if not empty #2972

Merged
merged 8 commits into from
Jun 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion include/pybind11/detail/init.h
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,13 @@ template <typename Class, typename T, typename O,
enable_if_t<std::is_convertible<O, handle>::value, int> = 0>
void setstate(value_and_holder &v_h, std::pair<T, O> &&result, bool need_alias) {
construct<Class>(v_h, std::move(result.first), need_alias);
setattr((PyObject *) v_h.inst, "__dict__", result.second);
auto d = handle(result.second);
if (PyDict_Check(d.ptr()) && PyDict_Size(d.ptr()) == 0) {
// Skipping setattr below, to not force use of py::dynamic_attr() for Class unnecessarily.
// See PR #2972 for details.
return;
}
setattr((PyObject *) v_h.inst, "__dict__", d);
EricCousineau-TRI marked this conversation as resolved.
Show resolved Hide resolved
}

/// Implementation for py::pickle(GetState, SetState)
Expand Down
56 changes: 56 additions & 0 deletions tests/test_pickling.cpp
Original file line number Diff line number Diff line change
@@ -1,14 +1,68 @@
// clang-format off
/*
tests/test_pickling.cpp -- pickle support

Copyright (c) 2016 Wenzel Jakob <[email protected]>
Copyright (c) 2021 The Pybind Development Team.

All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/

#include "pybind11_tests.h"

// clang-format on

#include <memory>
#include <stdexcept>
#include <utility>

namespace exercise_trampoline {

struct SimpleBase {
int num = 0;
virtual ~SimpleBase() = default;

// For compatibility with old clang versions:
SimpleBase() = default;
SimpleBase(const SimpleBase &) = default;
};

struct SimpleBaseTrampoline : SimpleBase {};

struct SimpleCppDerived : SimpleBase {};

void wrap(py::module m) {
py::class_<SimpleBase, SimpleBaseTrampoline>(m, "SimpleBase")
.def(py::init<>())
.def_readwrite("num", &SimpleBase::num)
.def(py::pickle(
[](const py::object &self) {
py::dict d;
if (py::hasattr(self, "__dict__"))
d = self.attr("__dict__");
return py::make_tuple(self.attr("num"), d);
},
[](const py::tuple &t) {
if (t.size() != 2)
throw std::runtime_error("Invalid state!");
auto cpp_state = std::unique_ptr<SimpleBase>(new SimpleBaseTrampoline);
cpp_state->num = t[0].cast<int>();
auto py_state = t[1].cast<py::dict>();
return std::make_pair(std::move(cpp_state), py_state);
}));

m.def("make_SimpleCppDerivedAsBase",
[]() { return std::unique_ptr<SimpleBase>(new SimpleCppDerived); });
m.def("check_dynamic_cast_SimpleCppDerived", [](const SimpleBase *base_ptr) {
return dynamic_cast<const SimpleCppDerived *>(base_ptr) != nullptr;
});
}

} // namespace exercise_trampoline

// clang-format off

TEST_SUBMODULE(pickling, m) {
// test_roundtrip
class Pickleable {
Expand Down Expand Up @@ -130,4 +184,6 @@ TEST_SUBMODULE(pickling, m) {
return std::make_pair(cpp_state, py_state);
}));
#endif

exercise_trampoline::wrap(m);
}
36 changes: 36 additions & 0 deletions tests/test_pickling.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,39 @@ def test_enum_pickle():

data = pickle.dumps(e.EOne, 2)
assert e.EOne == pickle.loads(data)


#
# exercise_trampoline
#
class SimplePyDerived(m.SimpleBase):
pass


def test_roundtrip_simple_py_derived():
p = SimplePyDerived()
p.num = 202
p.stored_in_dict = 303
data = pickle.dumps(p, pickle.HIGHEST_PROTOCOL)
p2 = pickle.loads(data)
assert isinstance(p2, SimplePyDerived)
assert p2.num == 202
assert p2.stored_in_dict == 303


def test_roundtrip_simple_cpp_derived():
p = m.make_SimpleCppDerivedAsBase()
assert m.check_dynamic_cast_SimpleCppDerived(p)
p.num = 404
if not env.PYPY:
# To ensure that this unit test is not accidentally invalidated.
with pytest.raises(AttributeError):
# Mimics the `setstate` C++ implementation.
setattr(p, "__dict__", {}) # noqa: B010
data = pickle.dumps(p, pickle.HIGHEST_PROTOCOL)
p2 = pickle.loads(data)
assert isinstance(p2, m.SimpleBase)
EricCousineau-TRI marked this conversation as resolved.
Show resolved Hide resolved
assert p2.num == 404
# Issue #3062: pickleable base C++ classes can incur object slicing
# if derived typeid is not registered with pybind11
assert not m.check_dynamic_cast_SimpleCppDerived(p2)