Skip to content

Commit

Permalink
Change implementation of the `__init__() must be called when overriding __init__` safety feature to work for any metaclass. (#30095)
Browse files Browse the repository at this point in the history

* Also wrap with `py::metaclass((PyObject *) &PyType_Type)`

* Transfer additional tests from PyCLIF python_multiple_inheritance_test.py

* Expand tests to fully cover wrapping with alternative metaclasses.

* Factor out `ensure_base_init_functions_were_called()`.

* Call from new `tp_init_intercepted()` (adopting mechanism first added in PyCLIF: google/clif@7cba87d).

* Remove `pybind11_meta_call()` (which was added with pybind/pybind11#2152).

* Bug fix (maybe actually two bugs?): simplify condition to `type->tp_init != tp_init_intercepted`

* Removing `Py_DECREF(self)` that leads to MSAN failure (Google toolchain).

```
==6380==WARNING: MemorySanitizer: use-of-uninitialized-value
    #0 0x5611589c9a58 in Py_DECREF third_party/python_runtime/v3_11/Include/object.h:537:9
...

  Uninitialized value was created by a heap deallocation
    #0 0x5611552757b0 in free third_party/llvm/llvm-project/compiler-rt/lib/msan/msan_interceptors.cpp:218:3
    #1 0x56115898e06b in _PyMem_RawFree third_party/python_runtime/v3_11/Objects/obmalloc.c:154:5
    #2 0x56115898f6ad in PyObject_Free third_party/python_runtime/v3_11/Objects/obmalloc.c:769:5
    #3 0x561158271bcc in PyObject_GC_Del third_party/python_runtime/v3_11/Modules/gcmodule.c:2407:5
    #4 0x7f21224b070c in pybind11_object_dealloc third_party/pybind11/include/pybind11/detail/class.h:483:5
    #5 0x5611589c2ed0 in subtype_dealloc third_party/python_runtime/v3_11/Objects/typeobject.c:1463:5
...
```

* IncludeCleaner fixes (Google toolchain).

* Restore `type->tp_call = pybind11_meta_call;` for PyPy only.

* pytest.skip("ensure_base_init_functions_were_called() does not work with PyPy and Python `type` as metaclass")

* Do not intercept our own `tp_init` function (`pybind11_object_init`).

* Add `derived_tp_init_registry` weakref-based cleanup.

* Replace `assert()` with `if` to resolve erroneous `lambda capture 'type' is not used` diagnostics (many CI jobs; seems to be a clang issue).

* Add `derived_tp_init_registry()->count(type) == 0` condition.

* Changes based on feedback from @rainwoodman

* Use PYBIND11_INIT_SAFETY_CHECKS_VIA_* macros, based on suggestion from @rainwoodman
  • Loading branch information
Ralf W. Grosse-Kunstleve authored Feb 1, 2024
1 parent 80c9ee6 commit 54f8341
Show file tree
Hide file tree
Showing 4 changed files with 234 additions and 38 deletions.
82 changes: 67 additions & 15 deletions include/pybind11/detail/class.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
#include "../attr.h"
#include "../options.h"

#include <cassert>
#include <unordered_map>

PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
PYBIND11_NAMESPACE_BEGIN(detail)

Expand Down Expand Up @@ -179,6 +182,36 @@ extern "C" inline PyObject *pybind11_meta_getattro(PyObject *obj, PyObject *name
return PyType_Type.tp_getattro(obj, name);
}

// Check every value/holder entry of `self` and verify that its holder was constructed,
// i.e. that the corresponding base __init__ function was actually run.
// On the first entry that is neither constructed nor redundant, a Python TypeError is
// set and false is returned; true is returned when nothing is missing.
// The refcount of `self` is never touched here: THE CALLER stays responsible for it.
inline bool ensure_base_init_functions_were_called(PyObject *self) {
    values_and_holders all_vhs(self);
    for (const auto &entry : all_vhs) {
        // Nothing to report for constructed holders or for redundant entries.
        if (entry.holder_constructed() || all_vhs.is_redundant_value_and_holder(entry)) {
            continue;
        }
        const auto type_name = get_fully_qualified_tp_name(entry.type->type);
        PyErr_Format(PyExc_TypeError,
                     "%.200s.__init__() must be called when overriding __init__",
                     type_name.c_str());
        return false;
    }
    return true;
}

// Select the mechanism used for the "`__init__()` must be called when overriding
// `__init__`" safety checks. See google/pywrapcc#30095 for background.
//
// Exactly one of the two macros below ends up defined. If either one is already defined
// before this point, that choice is respected and nothing is changed here.
//   - PYBIND11_INIT_SAFETY_CHECKS_VIA_INTERCEPTING_TP_INIT enables the
//     `tp_init_with_safety_checks` code path.
//   - PYBIND11_INIT_SAFETY_CHECKS_VIA_DEFAULT_PYBIND11_METACLASS enables the
//     `pybind11_meta_call` code path.
#if !defined(PYBIND11_INIT_SAFETY_CHECKS_VIA_INTERCEPTING_TP_INIT) \
&& !defined(PYBIND11_INIT_SAFETY_CHECKS_VIA_DEFAULT_PYBIND11_METACLASS)
# if !defined(PYPY_VERSION)
// With CPython the safety checks work for any metaclass.
// However, with PyPy this implementation does not work at all.
# define PYBIND11_INIT_SAFETY_CHECKS_VIA_INTERCEPTING_TP_INIT
# else
// With this the safety checks work only for the default `py::metaclass()`.
# define PYBIND11_INIT_SAFETY_CHECKS_VIA_DEFAULT_PYBIND11_METACLASS
# endif
#endif

#if defined(PYBIND11_INIT_SAFETY_CHECKS_VIA_DEFAULT_PYBIND11_METACLASS)
/// metaclass `__call__` function that is used to create all pybind11 objects.
extern "C" inline PyObject *pybind11_meta_call(PyObject *type, PyObject *args, PyObject *kwargs) {

Expand All @@ -188,20 +221,14 @@ extern "C" inline PyObject *pybind11_meta_call(PyObject *type, PyObject *args, P
return nullptr;
}

// Ensure that the base __init__ function(s) were called
values_and_holders vhs(self);
for (const auto &vh : vhs) {
if (!vh.holder_constructed() && !vhs.is_redundant_value_and_holder(vh)) {
PyErr_Format(PyExc_TypeError,
"%.200s.__init__() must be called when overriding __init__",
get_fully_qualified_tp_name(vh.type->type).c_str());
Py_DECREF(self);
return nullptr;
}
if (!ensure_base_init_functions_were_called(self)) {
Py_DECREF(self);
return nullptr;
}

return self;
}
#endif

/// Cleanup the type-info for a pybind11-registered type.
extern "C" inline void pybind11_meta_dealloc(PyObject *obj) {
Expand Down Expand Up @@ -268,7 +295,9 @@ inline PyTypeObject *make_default_metaclass() {
type->tp_base = type_incref(&PyType_Type);
type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;

#if defined(PYBIND11_INIT_SAFETY_CHECKS_VIA_DEFAULT_PYBIND11_METACLASS)
type->tp_call = pybind11_meta_call;
#endif

type->tp_setattro = pybind11_meta_setattro;
type->tp_getattro = pybind11_meta_getattro;
Expand Down Expand Up @@ -340,6 +369,33 @@ inline bool deregister_instance(instance *self, void *valptr, const type_info *t
return ret;
}

#if defined(PYBIND11_INIT_SAFETY_CHECKS_VIA_INTERCEPTING_TP_INIT)

/// Associative container keyed by Python type object, storing an `initproc` per type.
using derived_tp_init_registry_type = std::unordered_map<PyTypeObject *, initproc>;

/// Returns the single shared registry instance, creating it on first use.
inline derived_tp_init_registry_type *derived_tp_init_registry() {
    // The map is allocated once and deliberately never freed:
    // https://google.github.io/styleguide/cppguide.html#Static_and_Global_Variables
    static derived_tp_init_registry_type *const registry = new derived_tp_init_registry_type();
    return registry;
}

/// `tp_init` replacement: runs the `initproc` registered for the type of `self`, then
/// verifies that the base __init__ function(s) were called.
///
/// Returns the status of the registered `initproc`, except that a successful status (0)
/// is turned into -1 when `ensure_base_init_functions_were_called()` reports a failure.
/// Calls `pybind11_fail()` if no `initproc` is registered for the type of `self`.
extern "C" inline int tp_init_with_safety_checks(PyObject *self, PyObject *args, PyObject *kw) {
    assert(PyType_Check(self) == 0);
    auto *const registry = derived_tp_init_registry();
    const auto entry = registry->find(Py_TYPE(self));
    if (entry == registry->end()) {
        pybind11_fail("FATAL: Internal consistency check failed at " __FILE__
                      ":" PYBIND11_TOSTRING(__LINE__));
    }
    const initproc registered_tp_init = entry->second;
    const int status = registered_tp_init(self, args, kw);
    if (status != 0) {
        // The wrapped __init__ already failed; pass its status through unchanged.
        return status;
    }
    if (ensure_base_init_functions_were_called(self)) {
        return 0;
    }
    return -1; // No Py_DECREF here.
}

#endif // PYBIND11_INIT_SAFETY_CHECKS_VIA_INTERCEPTING_TP_INIT

/// Instance creation function for all pybind11 types. It allocates the internal instance layout
/// for holding C++ objects and holders. Allocation is done lazily (the first time the instance is
/// cast to a reference or pointer), and initialization is done by an `__init__` function.
Expand All @@ -360,11 +416,7 @@ inline PyObject *make_new_instance(PyTypeObject *type) {
return self;
}

/// Instance creation function for all pybind11 types. It only allocates space for the
/// C++ object, but doesn't call the constructor -- an `__init__` function must do that.
extern "C" inline PyObject *pybind11_object_new(PyTypeObject *type, PyObject *, PyObject *) {
return make_new_instance(type);
}
extern "C" inline PyObject *pybind11_object_new(PyTypeObject *type, PyObject *, PyObject *);

/// An `__init__` function constructs the C++ object. Users should provide at least one
/// of these using `py::init` or directly with `.def(__init__, ...)`. Otherwise, the
Expand Down
27 changes: 27 additions & 0 deletions include/pybind11/pybind11.h
Original file line number Diff line number Diff line change
Expand Up @@ -1207,6 +1207,33 @@ class cpp_function : public function {
}
};

PYBIND11_NAMESPACE_BEGIN(detail)

/// Instance creation function for all pybind11 types. It only allocates space for the
/// C++ object, but doesn't call the constructor -- an `__init__` function must do that.
///
/// When PYBIND11_INIT_SAFETY_CHECKS_VIA_INTERCEPTING_TP_INIT is defined, this function
/// additionally replaces the `tp_init` slot of `type` with `tp_init_with_safety_checks`,
/// after storing the previous slot value in `derived_tp_init_registry()`. That happens
/// only for types that pass all three conditions checked below.
extern "C" inline PyObject *pybind11_object_new(PyTypeObject *type, PyObject *, PyObject *) {
#if defined(PYBIND11_INIT_SAFETY_CHECKS_VIA_INTERCEPTING_TP_INIT)
// Leave the type alone if it still uses pybind11's own `pybind11_object_init`, if its
// slot already is `tp_init_with_safety_checks`, or if it already has a registry entry.
if (type->tp_init != pybind11_object_init && type->tp_init != tp_init_with_safety_checks
&& derived_tp_init_registry()->count(type) == 0) {
// Create a weak reference to `type` whose callback removes the registry entry again.
// NOTE(review): the callback presumably fires when `type` is garbage collected -- confirm.
weakref((PyObject *) type, cpp_function([type](handle wr) {
auto num_erased = derived_tp_init_registry()->erase(type);
// Exactly one entry is expected to be erased; anything else is treated as fatal.
if (num_erased != 1) {
pybind11_fail("FATAL: Internal consistency check failed at " __FILE__
":" PYBIND11_TOSTRING(__LINE__) ": num_erased="
+ std::to_string(num_erased));
}
// NOTE(review): presumably balances the reference given up by `.release()`
// below, so the weakref object itself is not leaked -- confirm.
wr.dec_ref();
}))
.release();
// Remember the current slot value first, then install the checking wrapper.
(*derived_tp_init_registry())[type] = type->tp_init;
type->tp_init = tp_init_with_safety_checks;
}
#endif // PYBIND11_INIT_SAFETY_CHECKS_VIA_INTERCEPTING_TP_INIT
return make_new_instance(type);
}

PYBIND11_NAMESPACE_END(detail)

/// Wrapper for Python extension modules
class module_ : public object {
public:
Expand Down
46 changes: 30 additions & 16 deletions tests/test_python_multiple_inheritance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ namespace test_python_multiple_inheritance {
// Copied from:
// https://github.com/google/clif/blob/5718e4d0807fd3b6a8187dde140069120b81ecef/clif/testing/python_multiple_inheritance.h

template <int> // Using int as a trick to easily generate a series of types.
struct CppBase {
explicit CppBase(int value) : base_value(value) {}
int get_base_value() const { return base_value; }
Expand All @@ -14,32 +15,45 @@ struct CppBase {
int base_value;
};

struct CppDrvd : CppBase {
explicit CppDrvd(int value) : CppBase(value), drvd_value(value * 3) {}
template <int SerNo>
struct CppDrvd : CppBase<SerNo> {
explicit CppDrvd(int value) : CppBase<SerNo>(value), drvd_value(value * 3) {}
int get_drvd_value() const { return drvd_value; }
void reset_drvd_value(int new_value) { drvd_value = new_value; }

int get_base_value_from_drvd() const { return get_base_value(); }
void reset_base_value_from_drvd(int new_value) { reset_base_value(new_value); }
int get_base_value_from_drvd() const { return CppBase<SerNo>::get_base_value(); }
void reset_base_value_from_drvd(int new_value) { CppBase<SerNo>::reset_base_value(new_value); }

private:
int drvd_value;
};

/// Registers `CppBase<SerNo>` and `CppDrvd<SerNo>` with module `m`.
///
/// \param m          module that receives both classes
/// \param name_base  Python-visible name for `CppBase<SerNo>`
/// \param name_drvd  Python-visible name for `CppDrvd<SerNo>`
/// \param extra      additional `py::class_` constructor arguments (for example
///                   `py::metaclass(...)`); the same arguments are applied to BOTH classes
template <int SerNo, typename... Extra>
void wrap_classes(py::module_ &m, const char *name_base, const char *name_drvd, Extra... extra) {
    // `extra` is consumed twice, so it is passed as lvalues. Forwarding (i.e. moving) the
    // by-value pack into the first `py::class_` would hand moved-from arguments to the
    // second one for any `Extra` type with a real move constructor.
    py::class_<CppBase<SerNo>>(m, name_base, extra...)
        .def(py::init<int>())
        .def("get_base_value", &CppBase<SerNo>::get_base_value)
        .def("reset_base_value", &CppBase<SerNo>::reset_base_value);

    py::class_<CppDrvd<SerNo>, CppBase<SerNo>>(m, name_drvd, extra...)
        .def(py::init<int>())
        .def("get_drvd_value", &CppDrvd<SerNo>::get_drvd_value)
        .def("reset_drvd_value", &CppDrvd<SerNo>::reset_drvd_value)
        .def("get_base_value_from_drvd", &CppDrvd<SerNo>::get_base_value_from_drvd)
        .def("reset_base_value_from_drvd", &CppDrvd<SerNo>::reset_base_value_from_drvd);
}

} // namespace test_python_multiple_inheritance

TEST_SUBMODULE(python_multiple_inheritance, m) {
using namespace test_python_multiple_inheritance;

py::class_<CppBase>(m, "CppBase")
.def(py::init<int>())
.def("get_base_value", &CppBase::get_base_value)
.def("reset_base_value", &CppBase::reset_base_value);

py::class_<CppDrvd, CppBase>(m, "CppDrvd")
.def(py::init<int>())
.def("get_drvd_value", &CppDrvd::get_drvd_value)
.def("reset_drvd_value", &CppDrvd::reset_drvd_value)
.def("get_base_value_from_drvd", &CppDrvd::get_base_value_from_drvd)
.def("reset_base_value_from_drvd", &CppDrvd::reset_base_value_from_drvd);
wrap_classes<0>(m, "CppBase0", "CppDrvd0");
wrap_classes<1>(m, "CppBase1", "CppDrvd1", py::metaclass((PyObject *) &PyType_Type));

m.attr("if_defined_PYBIND11_INIT_SAFETY_CHECKS_VIA_DEFAULT_PYBIND11_METACLASS") =
#if defined(PYBIND11_INIT_SAFETY_CHECKS_VIA_DEFAULT_PYBIND11_METACLASS)
true;
#else
false;
#endif
}
117 changes: 110 additions & 7 deletions tests/test_python_multiple_inheritance.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,78 @@
# Adapted from:
# https://github.com/google/clif/blob/5718e4d0807fd3b6a8187dde140069120b81ecef/clif/testing/python/python_multiple_inheritance_test.py
# https://github.com/google/clif/blob/7d388e1de7db5beeb3d7429c18a2776d8188f44f/clif/testing/python/python_multiple_inheritance_test.py

import pytest

from pybind11_tests import python_multiple_inheritance as m

#
# Using default py::metaclass() (used with py::class_<> for CppBase0, CppDrvd0):
#


# Pure-Python subclass of m.CppBase0 that adds nothing of its own.
class PC0(m.CppBase0):
    pass


# Python-level multiple inheritance: combines PC0 (a subclass of m.CppBase0) with
# m.CppDrvd0, without adding anything of its own.
class PPCC0(PC0, m.CppDrvd0):
    pass


# Overrides __init__ and DOES call the base __init__, passing value + 1.
class PCExplicitInitWithSuper0(m.CppBase0):
    def __init__(self, value):
        super().__init__(value + 1)


# Overrides __init__ WITHOUT calling the base __init__; `value` is simply discarded.
# testPCExplicitInitMissingSuper expects constructing this type to raise TypeError.
class PCExplicitInitMissingSuper0(m.CppBase0):
    def __init__(self, value):
        del value


# Second, separately defined type with the same body as PCExplicitInitMissingSuper0:
# overrides __init__ WITHOUT calling the base __init__.
class PCExplicitInitMissingSuperB0(m.CppBase0):
    def __init__(self, value):
        del value


#
# Using py::metaclass((PyObject *) &PyType_Type) (used with py::class_<> for CppBase1, CppDrvd1):
# COPY-PASTE block from above, replace 0 with 1:
#

class PC(m.CppBase):

# Pure-Python subclass of m.CppBase1 that adds nothing of its own.
class PC1(m.CppBase1):
    pass


class PPCC(PC, m.CppDrvd):
# Python-level multiple inheritance: combines PC1 (a subclass of m.CppBase1) with
# m.CppDrvd1, without adding anything of its own.
class PPCC1(PC1, m.CppDrvd1):
    pass


def test_PC():
d = PC(11)
# Overrides __init__ and DOES call the base __init__, passing value + 1.
class PCExplicitInitWithSuper1(m.CppBase1):
    def __init__(self, value):
        super().__init__(value + 1)


# Overrides __init__ WITHOUT calling the base __init__; `value` is simply discarded.
# testPCExplicitInitMissingSuper expects constructing this type to raise TypeError
# (that test is skipped for this type in one build configuration).
class PCExplicitInitMissingSuper1(m.CppBase1):
    def __init__(self, value):
        del value


# Second, separately defined type with the same body as PCExplicitInitMissingSuper1:
# overrides __init__ WITHOUT calling the base __init__.
class PCExplicitInitMissingSuperB1(m.CppBase1):
    def __init__(self, value):
        del value


@pytest.mark.parametrize("pc_type", [PC0, PC1])
def test_PC(pc_type):
    # Construct through the Python subclass, then read, reset and re-read the base value.
    obj = pc_type(11)
    assert obj.get_base_value() == 11
    obj.reset_base_value(13)
    assert obj.get_base_value() == 13


def test_PPCC():
d = PPCC(11)
@pytest.mark.parametrize(("ppcc_type"), [PPCC0, PPCC1])
def test_PPCC(ppcc_type):
d = ppcc_type(11)
assert d.get_drvd_value() == 33
d.reset_drvd_value(55)
assert d.get_drvd_value() == 55
Expand All @@ -33,3 +85,54 @@ def test_PPCC():
d.reset_base_value_from_drvd(30)
assert d.get_base_value() == 30
assert d.get_base_value_from_drvd() == 30


@pytest.mark.parametrize(
    "pc_type", [PCExplicitInitWithSuper0, PCExplicitInitWithSuper1]
)
def testPCExplicitInitWithSuper(pc_type):
    # The overriding __init__ passes value + 1 to the base __init__, hence 14 -> 15.
    instance = pc_type(14)
    assert instance.get_base_value() == 15


@pytest.mark.parametrize(
    "derived_type",
    [
        PCExplicitInitMissingSuper0,
        PCExplicitInitMissingSuperB0,
        PCExplicitInitMissingSuper1,
        PCExplicitInitMissingSuperB1,
    ],
)
def testPCExplicitInitMissingSuper(derived_type):
    # Constructing a type whose __init__ never calls the base __init__ must raise
    # TypeError. The series-1 types are skipped when the extension module reports that
    # PYBIND11_INIT_SAFETY_CHECKS_VIA_DEFAULT_PYBIND11_METACLASS is defined.
    series_1_types = (
        PCExplicitInitMissingSuper1,
        PCExplicitInitMissingSuperB1,
    )
    via_default_metaclass = (
        m.if_defined_PYBIND11_INIT_SAFETY_CHECKS_VIA_DEFAULT_PYBIND11_METACLASS
    )
    if via_default_metaclass and derived_type in series_1_types:
        pytest.skip(
            "PYBIND11_INIT_SAFETY_CHECKS_VIA_DEFAULT_PYBIND11_METACLASS is defined"
        )
    with pytest.raises(TypeError) as excinfo:
        derived_type(0)
    message = str(excinfo.value)
    assert message.endswith(".__init__() must be called when overriding __init__")


def test_derived_tp_init_registry_weakref_based_cleanup():
    # Every call of the helper defines a brand-new subclass of m.CppBase0 and creates
    # two instances of it; the loop repeats that many times with short-lived classes.
    def nested_function(i):
        class NestedClass(m.CppBase0):
            def __init__(self, value):
                super().__init__(value + 3)

        first = NestedClass(i + 7)
        second = NestedClass(i + 8)
        return (first.get_base_value(), second.get_base_value())

    for _ in range(100):
        assert nested_function(0) == (10, 11)
        assert nested_function(3) == (13, 14)

0 comments on commit 54f8341

Please sign in to comment.