Skip to content

Commit

Permalink
Attach python lifetime to shared_ptr passed to C++
Browse files Browse the repository at this point in the history
- Reference cycles are possible as a result, but shared_ptr is already susceptible to this in C++
  • Loading branch information
virtuald committed Feb 3, 2021
1 parent 918c909 commit 2b93ea7
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 1 deletion.
39 changes: 38 additions & 1 deletion include/pybind11/cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#pragma once

#include "gil.h"
#include "pytypes.h"
#include "detail/typeid.h"
#include "detail/descr.h"
Expand Down Expand Up @@ -1524,6 +1525,42 @@ struct holder_helper {
static auto get(const T &p) -> decltype(p.get()) { return p.get(); }
};

/// Another helper class for holders that helps construct derivative holders from
/// the original holder
template <typename T>
struct holder_retriever {
static auto get_derivative_holder(const value_and_holder &v_h) -> decltype(v_h.template holder<T>()) {
return v_h.template holder<T>();
}
};

template <typename T>
struct holder_retriever<std::shared_ptr<T>> {
struct shared_ptr_deleter {
// Note: deleter destructor fails on MSVC 2015 and GCC 4.8, so we manually
// call dec_ref here instead
handle ref;
void operator()(T *) {
gil_scoped_acquire gil;
ref.dec_ref();
}
};

static auto get_derivative_holder(const value_and_holder &v_h) -> std::shared_ptr<T> {
// The shared_ptr is always given to C++ code, so construct a new shared_ptr
// that is given a custom deleter. The custom deleter increments the python
// reference count to bind the python instance lifetime with the lifetime
// of the shared_ptr.
//
// This enables things like passing the last python reference of a subclass to a
// C++ function without the python reference dying.
//
// Reference cycles will cause a leak, but this is a limitation of shared_ptr
return std::shared_ptr<T>((T*)v_h.value_ptr(),
shared_ptr_deleter{handle((PyObject*)v_h.inst).inc_ref()});
}
};

/// Type caster for holder types like std::shared_ptr, etc.
/// The SFINAE hook is provided to help work around the current lack of support
/// for smart-pointer interoperability. Please consider it an implementation
Expand Down Expand Up @@ -1566,7 +1603,7 @@ struct copyable_holder_caster : public type_caster_base<type> {
bool load_value(value_and_holder &&v_h) {
if (v_h.holder_constructed()) {
value = v_h.value_ptr();
holder = v_h.template holder<holder_type>();
holder = holder_retriever<holder_type>::get_derivative_holder(v_h);
return true;
} else {
throw cast_error("Unable to cast from non-held to held instance (T& to Holder<T>) "
Expand Down
33 changes: 33 additions & 0 deletions tests/test_smart_ptr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,4 +397,37 @@ TEST_SUBMODULE(smart_ptr, m) {
list.append(py::cast(e));
return list;
});

// For testing whether a python subclass of a C++ object dies when the
// last python reference is lost
struct SpBase {
// returns true if the base virtual function is called
virtual bool is_base_used() { return true; }

SpBase() = default;
SpBase(const SpBase&) = delete;
virtual ~SpBase() = default;
};

struct PySpBase : SpBase {
bool is_base_used() override { PYBIND11_OVERRIDE(bool, SpBase, is_base_used); }
};

struct SpBaseTester {
std::shared_ptr<SpBase> get_object() { return m_obj; }
void set_object(std::shared_ptr<SpBase> obj) { m_obj = obj; }
bool is_base_used() { return m_obj->is_base_used(); }
std::shared_ptr<SpBase> m_obj;
};

py::class_<SpBase, std::shared_ptr<SpBase>, PySpBase>(m, "SpBase")
.def(py::init<>())
.def("is_base_used", &SpBase::is_base_used);

py::class_<SpBaseTester>(m, "SpBaseTester")
.def(py::init<>())
.def("get_object", &SpBaseTester::get_object)
.def("set_object", &SpBaseTester::set_object)
.def("is_base_used", &SpBaseTester::is_base_used)
.def_readwrite("obj", &SpBaseTester::m_obj);
}
62 changes: 62 additions & 0 deletions tests/test_smart_ptr.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,3 +316,65 @@ def test_shared_ptr_gc():
pytest.gc_collect()
for i, v in enumerate(el.get()):
assert i == v.value()


def test_shared_ptr_cpp_arg():
import weakref

class PyChild(m.SpBase):
def is_base_used(self):
return False

tester = m.SpBaseTester()

obj = PyChild()
objref = weakref.ref(obj)

# Pass the last python reference to the C++ function
tester.set_object(obj)
del obj
pytest.gc_collect()

# python reference is still around since C++ has it now
assert objref() is not None
assert tester.is_base_used() is False
assert tester.obj.is_base_used() is False
assert tester.get_object() is objref()


def test_shared_ptr_cpp_prop():
class PyChild(m.SpBase):
def is_base_used(self):
return False

tester = m.SpBaseTester()

# Set the last python reference as a property of the C++ object
tester.obj = PyChild()
pytest.gc_collect()

# python reference is still around since C++ has it now
assert tester.is_base_used() is False
assert tester.obj.is_base_used() is False


def test_shared_ptr_arg_identity():
import weakref

tester = m.SpBaseTester()

obj = m.SpBase()
objref = weakref.ref(obj)

tester.set_object(obj)
del obj
pytest.gc_collect()

# python reference is still around since C++ has it
assert objref() is not None
assert tester.get_object() is objref()

# python reference disappears once the C++ object releases it
tester.set_object(None)
pytest.gc_collect()
assert objref() is None

0 comments on commit 2b93ea7

Please sign in to comment.