diff --git a/clif/python/gen.py b/clif/python/gen.py index 76afe2a5..f1d49a50 100644 --- a/clif/python/gen.py +++ b/clif/python/gen.py @@ -398,14 +398,24 @@ def TypeObject(ht_qualname, tracked_slot_groups, yield ( 'static int tp_init_impl(PyObject* self, PyObject* args, PyObject* kw);' ) + yield ( + 'static int tp_init_intercepted(' + 'PyObject* self, PyObject* args, PyObject* kw);' + ) if not iterator: yield '' - yield '// %s tp_alloc' % pyname + yield '// %s tp_alloc_impl' % pyname yield ( 'static PyObject* tp_alloc_impl(PyTypeObject* type, Py_ssize_t nitems);' ) tp_slots['tp_alloc'] = 'tp_alloc_impl' - tp_slots['tp_new'] = 'PyType_GenericNew' + yield '' + yield '// %s tp_new_impl' % pyname + yield ( + 'static PyObject* tp_new_impl(PyTypeObject* type, PyObject* args,' + ' PyObject* kwds);' + ) + tp_slots['tp_new'] = 'tp_new_impl' yield '' # Use dtor for dynamic types (derived) to wind down malloc'ed C++ obj, so # the C++ dtors are run. @@ -433,6 +443,7 @@ def TypeObject(ht_qualname, tracked_slot_groups, # Use delete for static types (not derived), allocated with tp_alloc_impl. tp_slots['tp_free'] = 'tp_free_impl' yield '' + yield '// %s tp_free_impl' % pyname yield 'static void tp_free_impl(void* self) {' yield I+'delete %s(self);' % _Cast(wname) yield '}' @@ -479,6 +490,16 @@ def TypeObject(ht_qualname, tracked_slot_groups, yield I+'return ty;' yield '}' if ctor: + yield '' + yield '// Intentionally leak the unordered_map:' + yield ( + '// https://google.github.io/styleguide/cppguide.html' + '#Static_and_Global_Variables' + ) + yield ( + 'static auto* derived_tp_init_registry = new std::unordered_map<' + 'PyTypeObject*, int(*)(PyObject*, PyObject*, PyObject*)>;' + ) yield '' yield ( 'static int tp_init_impl(' @@ -530,6 +551,28 @@ def TypeObject(ht_qualname, tracked_slot_groups, yield I+'Py_XDECREF(init);' yield I+'return init? 0: -1;' yield '}' + yield '' + yield ( + 'static int tp_init_intercepted(' + 'PyObject* self, PyObject* args, PyObject* kw) {' + ) + yield I+'DCHECK(PyType_Check(self) == 0);' + yield ( + I+'const auto derived_tp_init = ' + 'derived_tp_init_registry->find(Py_TYPE(self));' + ) + yield I+'CHECK(derived_tp_init != derived_tp_init_registry->end());' + yield I+'int status = (*derived_tp_init->second)(self, args, kw);' + yield I+'if (status == 0 &&' + yield I+' reinterpret_cast(self)->cpp.get() == nullptr) {' + yield I+' Py_DECREF(self);' + yield I+' PyErr_Format(PyExc_TypeError,' + yield I+' "%s.__init__() must be called when"' + yield I+' " overriding __init__", wrapper_Type->tp_name);' + yield I+' return -1;' + yield I+'}' + yield I+'return status;' + yield '}' if not iterator: yield '' yield ( @@ -543,6 +586,19 @@ def TypeObject(ht_qualname, tracked_slot_groups, yield I+'PyObject* self = %s(wobj);' % _Cast() yield I+'return PyObject_Init(self, %s);' % wtype yield '}' + yield '' + yield ( + 'static PyObject* tp_new_impl(PyTypeObject* type, PyObject* args,' + ' PyObject* kwds) {' + ) + if ctor: + yield I+'if (type->tp_init != tp_init_impl &&' + yield I+' derived_tp_init_registry->count(type) == 0) {' + yield I+I+'(*derived_tp_init_registry)[type] = type->tp_init;' + yield I+I+'type->tp_init = tp_init_intercepted;' + yield I+'}' + yield I+'return PyType_GenericNew(type, args, kwds);' + yield '}' def CreateInputParameter(func_name, ast_param, arg, args): diff --git a/clif/testing/python/python_multiple_inheritance_test.py b/clif/testing/python/python_multiple_inheritance_test.py index 47d81a96..74579e0f 100644 --- a/clif/testing/python/python_multiple_inheritance_test.py +++ b/clif/testing/python/python_multiple_inheritance_test.py @@ -13,6 +13,7 @@ # limitations under the License. from absl.testing import absltest +from absl.testing import parameterized from clif.testing.python import python_multiple_inheritance as tm @@ -32,7 +33,25 @@ def __init__(self, value): tm.CppDrvd.__init__(self, value + 1) -class PythonMultipleInheritanceTest(absltest.TestCase): +class PCExplicitInitWithSuper(tm.CppBase): + + def __init__(self, value): + super().__init__(value + 1) + + +class PCExplicitInitMissingSuper(tm.CppBase): + + def __init__(self, value): + del value + + +class PCExplicitInitMissingSuper2(tm.CppBase): + + def __init__(self, value): + del value + + +class PythonMultipleInheritanceTest(parameterized.TestCase): def testPC(self): d = PC(11) @@ -80,6 +99,22 @@ def testPPCCInit(self): self.assertEqual(d.get_base_value(), (30, 20)[i]) self.assertEqual(d.get_base_value_from_drvd(), 30) + def testPCExplicitInitWithSuper(self): + d = PCExplicitInitWithSuper(14) + self.assertEqual(d.get_base_value(), 15) + + @parameterized.parameters( + PCExplicitInitMissingSuper, PCExplicitInitMissingSuper2 + ) + def testPCExplicitInitMissingSuper(self, derived_type): + with self.assertRaises(TypeError) as ctx: + derived_type(0) + self.assertEndsWith( + str(ctx.exception), + "python_multiple_inheritance.CppBase.__init__() must be called when" + " overriding __init__", + ) + if __name__ == "__main__": absltest.main()