Skip to content

Commit

Permalink
Add all_type_info_check_for_divergence() and some tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ralf W. Grosse-Kunstleve committed Nov 5, 2023
1 parent f3bb31e commit 0a9599f
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 0 deletions.
35 changes: 35 additions & 0 deletions include/pybind11/detail/type_caster_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,40 @@ inline void all_type_info_add_base_most_derived_first(std::vector<type_info *> &
bases.push_back(addl_base);
}

inline void all_type_info_check_for_divergence(const std::vector<type_info *> &bases) {
using sz_t = std::size_t;
sz_t n = bases.size();
if (n < 3) {
return;
}
std::vector<sz_t> cluster_ids;
cluster_ids.reserve(n);
for (sz_t ci = 0; ci < n; ci++) {
cluster_ids.push_back(ci);
}
for (sz_t i = 0; i < n - 1; i++) {
if (cluster_ids[i] != i) {
continue;
}
for (sz_t j = i + 1; j < n; j++) {
if (PyType_IsSubtype(bases[i]->type, bases[j]->type) != 0) {
sz_t k = cluster_ids[j];
if (k == j) {
cluster_ids[j] = i;
} else {
PyErr_Format(
PyExc_TypeError,
"bases include diverging derived types: base=%s, derived1=%s, derived2=%s",
bases[j]->type->tp_name,
bases[k]->type->tp_name,
bases[i]->type->tp_name);
throw error_already_set();
}
}
}
}
}

// Populates a just-created cache entry.
PYBIND11_NOINLINE void all_type_info_populate(PyTypeObject *t, std::vector<type_info *> &bases) {
assert(bases.empty());
Expand Down Expand Up @@ -168,6 +202,7 @@ PYBIND11_NOINLINE void all_type_info_populate(PyTypeObject *t, std::vector<type_
}
}
}
all_type_info_check_for_divergence(bases);
}

/**
Expand Down
19 changes: 19 additions & 0 deletions tests/test_python_multiple_inheritance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,18 @@ struct CppDrvd : CppBase {
int drvd_value;
};

struct CppDrvd2 : CppBase {
explicit CppDrvd2(int value) : CppBase(value), drvd2_value(value * 5) {}
int get_drvd2_value() const { return drvd2_value; }
void reset_drvd2_value(int new_value) { drvd2_value = new_value; }

int get_base_value_from_drvd2() const { return get_base_value(); }
void reset_base_value_from_drvd2(int new_value) { reset_base_value(new_value); }

private:
int drvd2_value;
};

} // namespace test_python_multiple_inheritance

TEST_SUBMODULE(python_multiple_inheritance, m) {
Expand All @@ -42,4 +54,11 @@ TEST_SUBMODULE(python_multiple_inheritance, m) {
.def("reset_drvd_value", &CppDrvd::reset_drvd_value)
.def("get_base_value_from_drvd", &CppDrvd::get_base_value_from_drvd)
.def("reset_base_value_from_drvd", &CppDrvd::reset_base_value_from_drvd);

py::class_<CppDrvd2, CppBase>(m, "CppDrvd2")
.def(py::init<int>())
.def("get_drvd2_value", &CppDrvd2::get_drvd2_value)
.def("reset_drvd2_value", &CppDrvd2::reset_drvd2_value)
.def("get_base_value_from_drvd2", &CppDrvd2::get_base_value_from_drvd2)
.def("reset_base_value_from_drvd2", &CppDrvd2::reset_base_value_from_drvd2);
}
36 changes: 36 additions & 0 deletions tests/test_python_multiple_inheritance.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Adapted from:
# https://github.com/google/clif/blob/5718e4d0807fd3b6a8187dde140069120b81ecef/clif/testing/python/python_multiple_inheritance_test.py

import pytest

from pybind11_tests import python_multiple_inheritance as m


Expand All @@ -12,6 +14,22 @@ class PPCC(PC, m.CppDrvd):
pass


class PPPCCC(PPCC, m.CppDrvd2):
pass


class PC1(m.CppDrvd):
pass


class PC2(m.CppDrvd2):
pass


class PCD(PC1, PC2):
pass


def test_PC():
d = PC(11)
assert d.get_base_value() == 11
Expand All @@ -33,3 +51,21 @@ def test_PPCC():
d.reset_base_value_from_drvd(30)
assert d.get_base_value() == 30
assert d.get_base_value_from_drvd() == 30


def NOtest_PPPCCC():
# terminate called after throwing an instance of 'pybind11::error_already_set'
# what(): TypeError: bases include diverging derived types:
# base=pybind11_tests.python_multiple_inheritance.CppBase,
# derived1=pybind11_tests.python_multiple_inheritance.CppDrvd,
# derived2=pybind11_tests.python_multiple_inheritance.CppDrvd2
PPPCCC(11)


def test_PCD():
# This escapes all_type_info_check_for_divergence() because CppBase does not appear in bases.
with pytest.raises(
TypeError,
match=r"CppDrvd2\.__init__\(\) must be called when overriding __init__$",
):
PCD(11)

0 comments on commit 0a9599f

Please sign in to comment.