From 58d85e9e1d64ec7ef72e92c8875dc133f690a36a Mon Sep 17 00:00:00 2001 From: Jason Rhinelander Date: Tue, 16 May 2017 11:07:28 -0400 Subject: [PATCH] Override deduced Base class when defining Derived methods When defining method from a member function pointer (e.g. `.def("f", &Derived::f)`) we run into a problem if `&Derived::f` is actually implemented in some base class `Base` when `Base` isn't pybind-registered. This happens because the class type is deduced from the member function pointer, which then becomes a lambda with first argument this deduced type. For a base class implementation, the deduced type is `Base`, not `Derived`, and so we generate and registered an overload which takes a `Base *` as first argument. Trying to call this fails if `Base` isn't registered (e.g. because it's an implementation detail class that isn't intended to be exposed to Python) because the type caster for an unregistered type always fails. This commit adds a `method_adaptor` function that rebinds a member function to a derived type member function and otherwise (i.e. regular functions/lambda) leaves the argument as-is. This is now used for class definitions so that they are bound with type being registered rather than a potential base type. A closely related fix in this commit is to similarly update the lambdas used for `def_readwrite` (and related) to bind to the class type being registered rather than the deduced type so that registering a property that resolves to a base class member similarly generates a usable function. Fixes #854, #910. Co-Authored-By: Dean Moldovan --- include/pybind11/common.h | 5 ++++ include/pybind11/pybind11.h | 35 +++++++++++++++++----- tests/test_methods_and_attributes.cpp | 43 +++++++++++++++++++++++++-- tests/test_methods_and_attributes.py | 20 +++++++++++++ 4 files changed, 93 insertions(+), 10 deletions(-) diff --git a/include/pybind11/common.h b/include/pybind11/common.h index 1a5f7fa4e8..ed1bb610bc 100644 --- a/include/pybind11/common.h +++ b/include/pybind11/common.h @@ -620,6 +620,11 @@ using exactly_one_t = typename exactly_one::type; template struct deferred_type { using type = T; }; template using deferred_t = typename deferred_type::type; +/// Like is_base_of, but requires a strict base (i.e. `is_strict_base_of::value == false`, +/// unlike `std::is_base_of`) +template using is_strict_base_of = bool_constant< + std::is_base_of::value && !std::is_same::value>; + template class Base> struct is_template_base_of_impl { template static std::true_type check(Base *); diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index e922903675..6084d968dd 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -910,11 +910,22 @@ inline void call_operator_delete(void *p) { ::operator delete(p); } NAMESPACE_END(detail) +/// Given a pointer to a member function, cast it to its `Derived` version. +/// Forward everything else unchanged. +template +auto method_adaptor(F &&f) -> decltype(std::forward(f)) { return std::forward(f); } + +template +auto method_adaptor(Return (Class::*pmf)(Args...)) -> Return (Derived::*)(Args...) { return pmf; } + +template +auto method_adaptor(Return (Class::*pmf)(Args...) const) -> Return (Derived::*)(Args...) const { return pmf; } + template class class_ : public detail::generic_type { template using is_holder = detail::is_holder_type; - template using is_subtype = detail::bool_constant::value && !std::is_same::value>; - template using is_base = detail::bool_constant::value && !std::is_same::value>; + template using is_subtype = detail::is_strict_base_of; + template using is_base = detail::is_strict_base_of; // struct instead of using here to help MSVC: template struct is_valid_class_option : detail::any_of, is_subtype, is_base> {}; @@ -980,7 +991,7 @@ class class_ : public detail::generic_type { template class_ &def(const char *name_, Func&& f, const Extra&... extra) { - cpp_function cf(std::forward(f), name(name_), is_method(*this), + cpp_function cf(method_adaptor(std::forward(f)), name(name_), is_method(*this), sibling(getattr(*this, name_, none())), extra...); attr(cf.name()) = cf; return *this; @@ -1044,15 +1055,17 @@ class class_ : public detail::generic_type { template class_ &def_readwrite(const char *name, D C::*pm, const Extra&... extra) { - cpp_function fget([pm](const C &c) -> const D &{ return c.*pm; }, is_method(*this)), - fset([pm](C &c, const D &value) { c.*pm = value; }, is_method(*this)); + static_assert(std::is_base_of::value, "def_readwrite() requires a class member (or base class member)"); + cpp_function fget([pm](const type &c) -> const D &{ return c.*pm; }, is_method(*this)), + fset([pm](type &c, const D &value) { c.*pm = value; }, is_method(*this)); def_property(name, fget, fset, return_value_policy::reference_internal, extra...); return *this; } template class_ &def_readonly(const char *name, const D C::*pm, const Extra& ...extra) { - cpp_function fget([pm](const C &c) -> const D &{ return c.*pm; }, is_method(*this)); + static_assert(std::is_base_of::value, "def_readonly() requires a class member (or base class member)"); + cpp_function fget([pm](const type &c) -> const D &{ return c.*pm; }, is_method(*this)); def_property_readonly(name, fget, return_value_policy::reference_internal, extra...); return *this; } @@ -1075,7 +1088,8 @@ class class_ : public detail::generic_type { /// Uses return_value_policy::reference_internal by default template class_ &def_property_readonly(const char *name, const Getter &fget, const Extra& ...extra) { - return def_property_readonly(name, cpp_function(fget), return_value_policy::reference_internal, extra...); + return def_property_readonly(name, cpp_function(method_adaptor(fget)), + return_value_policy::reference_internal, extra...); } /// Uses cpp_function's return_value_policy by default @@ -1097,9 +1111,14 @@ class class_ : public detail::generic_type { } /// Uses return_value_policy::reference_internal by default + template + class_ &def_property(const char *name, const Getter &fget, const Setter &fset, const Extra& ...extra) { + return def_property(name, fget, cpp_function(method_adaptor(fset)), extra...); + } template class_ &def_property(const char *name, const Getter &fget, const cpp_function &fset, const Extra& ...extra) { - return def_property(name, cpp_function(fget), fset, return_value_policy::reference_internal, extra...); + return def_property(name, cpp_function(method_adaptor(fget)), fset, + return_value_policy::reference_internal, extra...); } /// Uses cpp_function's return_value_policy by default diff --git a/tests/test_methods_and_attributes.cpp b/tests/test_methods_and_attributes.cpp index 670f6c3b72..e8e1b52243 100644 --- a/tests/test_methods_and_attributes.cpp +++ b/tests/test_methods_and_attributes.cpp @@ -159,7 +159,7 @@ template <> struct type_caster { }; }} -/// Issue/PR #648: bad arg default debugging output +// Issue/PR #648: bad arg default debugging output class NotRegistered {}; // Test None-allowed py::arg argument policy @@ -177,6 +177,23 @@ struct StrIssue { StrIssue(int i) : val{i} {} }; +// Issues #854, #910: incompatible function args when member function/pointer is in unregistered base class +class UnregisteredBase { +public: + void do_nothing() const {} + void increase_value() { rw_value++; ro_value += 0.25; } + void set_int(int v) { rw_value = v; } + int get_int() const { return rw_value; } + double get_double() const { return ro_value; } + int rw_value = 42; + double ro_value = 1.25; +}; +class RegisteredDerived : public UnregisteredBase { +public: + using UnregisteredBase::UnregisteredBase; + double sum() const { return rw_value + ro_value; } +}; + test_initializer methods_and_attributes([](py::module &m) { py::class_ emna(m, "ExampleMandA"); emna.def(py::init<>()) @@ -325,7 +342,7 @@ test_initializer methods_and_attributes([](py::module &m) { m.def("ints_preferred", [](int i) { return i / 2; }, py::arg("i")); m.def("ints_only", [](int i) { return i / 2; }, py::arg("i").noconvert()); - /// Issue/PR #648: bad arg default debugging output + // Issue/PR #648: bad arg default debugging output #if !defined(NDEBUG) m.attr("debug_enabled") = true; #else @@ -360,4 +377,26 @@ test_initializer methods_and_attributes([](py::module &m) { .def("__str__", [](const StrIssue &si) { return "StrIssue[" + std::to_string(si.val) + "]"; } ); + + // Issues #854/910: incompatible function args when member function/pointer is in unregistered + // base class The methods and member pointers below actually resolve to members/pointers in + // UnregisteredBase; before this test/fix they would be registered via lambda with a first + // argument of an unregistered type, and thus uncallable. + py::class_(m, "RegisteredDerived") + .def(py::init<>()) + .def("do_nothing", &RegisteredDerived::do_nothing) + .def("increase_value", &RegisteredDerived::increase_value) + .def_readwrite("rw_value", &RegisteredDerived::rw_value) + .def_readonly("ro_value", &RegisteredDerived::ro_value) + // These should trigger a static_assert if uncommented + //.def_readwrite("fails", &SimpleValue::value) // should trigger a static_assert if uncommented + //.def_readonly("fails", &SimpleValue::value) // should trigger a static_assert if uncommented + .def_property("rw_value_prop", &RegisteredDerived::get_int, &RegisteredDerived::set_int) + .def_property_readonly("ro_value_prop", &RegisteredDerived::get_double) + // This one is in the registered class: + .def("sum", &RegisteredDerived::sum) + ; + + using Adapted = decltype(py::method_adaptor(&RegisteredDerived::do_nothing)); + static_assert(std::is_same::value, ""); }); diff --git a/tests/test_methods_and_attributes.py b/tests/test_methods_and_attributes.py index 95049cfd40..afe8d28680 100644 --- a/tests/test_methods_and_attributes.py +++ b/tests/test_methods_and_attributes.py @@ -457,3 +457,23 @@ def test_str_issue(msg): Invoked with: 'no', 'such', 'constructor' """ + + +def test_unregistered_base_implementations(): + from pybind11_tests import RegisteredDerived + + a = RegisteredDerived() + a.do_nothing() + assert a.rw_value == 42 + assert a.ro_value == 1.25 + a.rw_value += 5 + assert a.sum() == 48.25 + a.increase_value() + assert a.rw_value == 48 + assert a.ro_value == 1.5 + assert a.sum() == 49.5 + assert a.rw_value_prop == 48 + a.rw_value_prop += 1 + assert a.rw_value_prop == 49 + a.increase_value() + assert a.ro_value_prop == 1.75